sgd.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. # mypy: allow-untyped-defs
  2. r"""Implementation for Stochastic Gradient Descent optimizer."""
  3. from typing import cast
  4. import torch
  5. from torch import Tensor
  6. from .optimizer import (
  7. _default_to_fused_or_foreach,
  8. _device_dtype_check_for_fused,
  9. _differentiable_doc,
  10. _foreach_doc,
  11. _fused_doc,
  12. _maximize_doc,
  13. _params_doc,
  14. _to_scalar,
  15. _use_grad_for_differentiable,
  16. DeviceDict,
  17. Optimizer,
  18. ParamsT,
  19. )
  20. __all__ = ["SGD", "sgd"]
  21. class SGD(Optimizer): # noqa: D101
  22. def __init__(
  23. self,
  24. params: ParamsT,
  25. lr: float | Tensor = 1e-3,
  26. momentum: float = 0,
  27. dampening: float = 0,
  28. weight_decay: float | Tensor = 0,
  29. nesterov: bool = False,
  30. *,
  31. maximize: bool = False,
  32. foreach: bool | None = None,
  33. differentiable: bool = False,
  34. fused: bool | None = None,
  35. ) -> None: # noqa: D107
  36. if isinstance(lr, Tensor) and lr.numel() != 1:
  37. raise ValueError("Tensor lr must be 1-element")
  38. if lr < 0.0:
  39. raise ValueError(f"Invalid learning rate: {lr}")
  40. if momentum < 0.0:
  41. raise ValueError(f"Invalid momentum value: {momentum}")
  42. if weight_decay < 0.0:
  43. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  44. defaults = {
  45. "lr": lr,
  46. "momentum": momentum,
  47. "dampening": dampening,
  48. "weight_decay": weight_decay,
  49. "nesterov": nesterov,
  50. "maximize": maximize,
  51. "foreach": foreach,
  52. "differentiable": differentiable,
  53. "fused": fused,
  54. }
  55. if nesterov and (momentum <= 0 or dampening != 0):
  56. raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  57. super().__init__(params, defaults)
  58. if fused:
  59. self._step_supports_amp_scaling = True
  60. self._need_device_dtype_check_for_fused = True
  61. if differentiable:
  62. raise RuntimeError("`fused` does not support `differentiable`")
  63. if foreach:
  64. raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
  65. def __setstate__(self, state): # noqa: D105
  66. super().__setstate__(state)
  67. for group in self.param_groups:
  68. group.setdefault("nesterov", False)
  69. group.setdefault("maximize", False)
  70. group.setdefault("foreach", None)
  71. group.setdefault("differentiable", False)
  72. group.setdefault("fused", False)
  73. def _init_group(self, group, params, grads, momentum_buffer_list):
  74. has_sparse_grad = False
  75. for p in group["params"]:
  76. if p.grad is not None:
  77. if group["fused"] and getattr(
  78. self, "_need_device_dtype_check_for_fused", True
  79. ):
  80. _device_dtype_check_for_fused(p)
  81. self._need_device_dtype_check_for_fused = False
  82. params.append(p)
  83. grads.append(p.grad)
  84. if p.grad.is_sparse:
  85. has_sparse_grad = True
  86. if group["momentum"] != 0:
  87. state = self.state[p]
  88. momentum_buffer_list.append(state.get("momentum_buffer"))
  89. return has_sparse_grad
  90. @_use_grad_for_differentiable
  91. def step(self, closure=None):
  92. """Perform a single optimization step.
  93. Args:
  94. closure (Callable, optional): A closure that reevaluates the model
  95. and returns the loss.
  96. """
  97. loss = None
  98. if closure is not None:
  99. with torch.enable_grad():
  100. loss = closure()
  101. for group in self.param_groups:
  102. params: list[Tensor] = []
  103. grads: list[Tensor] = []
  104. momentum_buffer_list: list[Tensor | None] = []
  105. has_sparse_grad = self._init_group(
  106. group, params, grads, momentum_buffer_list
  107. )
  108. sgd(
  109. params,
  110. grads,
  111. momentum_buffer_list,
  112. weight_decay=group["weight_decay"],
  113. momentum=group["momentum"],
  114. lr=group["lr"],
  115. dampening=group["dampening"],
  116. nesterov=group["nesterov"],
  117. maximize=group["maximize"],
  118. has_sparse_grad=has_sparse_grad,
  119. foreach=group["foreach"],
  120. fused=group["fused"],
  121. grad_scale=getattr(self, "grad_scale", None),
  122. found_inf=getattr(self, "found_inf", None),
  123. )
  124. if group["momentum"] != 0:
  125. # update momentum_buffers in state
  126. for p, momentum_buffer in zip(
  127. params, momentum_buffer_list, strict=True
  128. ):
  129. state = self.state[p]
  130. state["momentum_buffer"] = momentum_buffer
  131. return loss
  132. SGD.__doc__ = (
  133. r"""Implements stochastic gradient descent (optionally with momentum).
  134. .. math::
  135. \begin{aligned}
  136. &\rule{110mm}{0.4pt} \\
  137. &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
  138. \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
  139. &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
  140. \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
  141. &\rule{110mm}{0.4pt} \\
  142. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  143. &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
  144. &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
  145. &\hspace{5mm}\textbf{else} \\
  146. &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  147. &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
  148. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  149. &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
  150. &\hspace{10mm}\textbf{if} \: t > 1 \\
  151. &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
  152. &\hspace{10mm}\textbf{else} \\
  153. &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
  154. &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
  155. &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
  156. &\hspace{10mm}\textbf{else} \\[-1.ex]
  157. &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
  158. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
  159. &\rule{110mm}{0.4pt} \\[-1.ex]
  160. &\bf{return} \: \theta_t \\[-1.ex]
  161. &\rule{110mm}{0.4pt} \\[-1.ex]
  162. \end{aligned}
  163. Nesterov momentum is based on the formula from
  164. `On the importance of initialization and momentum in deep learning`__.
  165. """
  166. + rf"""
  167. Args:
  168. {_params_doc}
  169. lr (float, Tensor, optional): learning rate (default: 1e-3)
  170. momentum (float, optional): momentum factor (default: 0)
  171. dampening (float, optional): dampening for momentum (default: 0)
  172. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  173. nesterov (bool, optional): enables Nesterov momentum. Only applicable
  174. when momentum is non-zero. (default: False)
  175. {_maximize_doc}
  176. {_foreach_doc}
  177. {_differentiable_doc}
  178. {_fused_doc}
  179. """
  180. + r"""
  181. Example:
  182. >>> # xdoctest: +SKIP
  183. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  184. >>> optimizer.zero_grad()
  185. >>> loss_fn(model(input), target).backward()
  186. >>> optimizer.step()
  187. __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
  188. .. note::
  189. The implementation of SGD with Momentum/Nesterov subtly differs from
  190. Sutskever et al. and implementations in some other frameworks.
  191. Considering the specific case of Momentum, the update can be written as
  192. .. math::
  193. \begin{aligned}
  194. v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
  195. p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
  196. \end{aligned}
  197. where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
  198. parameters, gradient, velocity, and momentum respectively.
  199. This is in contrast to Sutskever et al. and
  200. other frameworks which employ an update of the form
  201. .. math::
  202. \begin{aligned}
  203. v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
  204. p_{t+1} & = p_{t} - v_{t+1}.
  205. \end{aligned}
  206. The Nesterov version is analogously modified.
  207. Moreover, the initial value of the momentum buffer is set to the
  208. gradient value at the first step. This is in contrast to some other
  209. frameworks that initialize it to all zeros. One notable side effect
  210. of this decision is that the first momentum value will not be scaled
  211. by dampening. Dampening will be applied starting at the second step.
  212. """
  213. )
def sgd(
    params: list[Tensor],
    d_p_list: list[Tensor],
    momentum_buffer_list: list[Tensor | None],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: bool | None = None,
    fused: bool | None = None,
    grad_scale: Tensor | None = None,
    found_inf: Tensor | None = None,
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
) -> None:
    r"""Functional API that performs SGD algorithm computation.

    Picks one of the three implementations in this module
    (``_multi_tensor_sgd``, ``_fused_sgd`` or ``_single_tensor_sgd``) and
    forwards every argument to it. ``params`` are updated in place.

    See :class:`~torch.optim.SGD` for details.

    Args:
        params: parameters to update.
        d_p_list: gradients, passed to the implementation as ``grads``.
        momentum_buffer_list: momentum buffers, one slot per parameter. The
            implementations may replace ``None`` slots with new buffers.
        has_sparse_grad: forwarded to the implementation unchanged.
        foreach: force (``True``) or forbid (``False``) the multi-tensor
            implementation. ``None`` means "not specified".
        fused: force (``True``) or forbid (``False``) the fused
            implementation. ``None`` means "not specified".
        grad_scale: forwarded to the implementation unchanged.
        found_inf: forwarded to the implementation unchanged.
        weight_decay: weight decay coefficient.
        momentum: momentum factor.
        lr: learning rate.
        dampening: dampening for momentum.
        nesterov: whether to apply Nesterov momentum.
        maximize: whether to step along the gradient instead of against it.

    Raises:
        RuntimeError: if ``foreach`` or ``fused`` ends up ``True`` while
            running under ``torch.jit.script``.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if foreach is None and fused is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            fused, foreach = _default_to_fused_or_foreach(
                params, differentiable=False, use_fused=False
            )
        else:
            foreach = False
            fused = False
    # Only one of the two was specified: the unspecified one is treated as off.
    if foreach is None:
        foreach = False
    if fused is None:
        fused = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # foreach takes precedence over fused when both flags are set; the
    # single-tensor loop is the fallback and the only choice when scripting.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    elif fused and not torch.jit.is_scripting():
        func = _fused_sgd
    else:
        func = _single_tensor_sgd

    func(
        params,
        d_p_list,
        momentum_buffer_list,
        weight_decay=weight_decay,
        momentum=momentum,
        lr=lr,
        dampening=dampening,
        nesterov=nesterov,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
  278. def _single_tensor_sgd(
  279. params: list[Tensor],
  280. grads: list[Tensor],
  281. momentum_buffer_list: list[Tensor | None],
  282. grad_scale: Tensor | None,
  283. found_inf: Tensor | None,
  284. *,
  285. weight_decay: float,
  286. momentum: float,
  287. lr: float,
  288. dampening: float,
  289. nesterov: bool,
  290. maximize: bool,
  291. has_sparse_grad: bool,
  292. ) -> None:
  293. if grad_scale is not None or found_inf is not None:
  294. raise AssertionError("Expected grad_scale and found_inf to be None")
  295. if not torch.jit.is_scripting():
  296. lr = _to_scalar(lr)
  297. for i, param in enumerate(params):
  298. grad = grads[i] if not maximize else -grads[i]
  299. if weight_decay != 0:
  300. # Nested if is necessary to bypass jitscript rules
  301. if isinstance(weight_decay, Tensor):
  302. if weight_decay.requires_grad:
  303. # usually this is the differentiable path, which is why the param.clone() is needed
  304. grad = grad.addcmul_(param.clone(), weight_decay)
  305. else:
  306. # pyrefly: ignore [bad-argument-type]
  307. grad = grad.add(param, alpha=weight_decay)
  308. else:
  309. grad = grad.add(param, alpha=weight_decay)
  310. if momentum != 0:
  311. buf = momentum_buffer_list[i]
  312. if buf is None:
  313. buf = grad.detach().clone()
  314. momentum_buffer_list[i] = buf
  315. else:
  316. buf.mul_(momentum).add_(grad, alpha=1 - dampening)
  317. if nesterov:
  318. grad = grad.add(buf, alpha=momentum)
  319. else:
  320. grad = buf
  321. # Nested if is necessary to bypass jitscript rules
  322. if isinstance(lr, Tensor):
  323. if lr.requires_grad:
  324. param.addcmul_(grad, lr, value=-1)
  325. else:
  326. # pyrefly: ignore [bad-argument-type]
  327. param.add_(grad, alpha=-lr)
  328. else:
  329. param.add_(grad, alpha=-lr)
  330. def _multi_tensor_sgd(
  331. params: list[Tensor],
  332. grads: list[Tensor],
  333. momentum_buffer_list: list[Tensor | None],
  334. grad_scale: Tensor | None,
  335. found_inf: Tensor | None,
  336. *,
  337. weight_decay: float,
  338. momentum: float,
  339. lr: float,
  340. dampening: float,
  341. nesterov: bool,
  342. maximize: bool,
  343. has_sparse_grad: bool,
  344. ) -> None:
  345. if grad_scale is not None or found_inf is not None:
  346. raise AssertionError("Expected grad_scale and found_inf to be None")
  347. if len(params) == 0:
  348. return
  349. lr = _to_scalar(lr)
  350. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  351. [params, grads, momentum_buffer_list], # type: ignore[list-item]
  352. with_indices=True,
  353. )
  354. for (
  355. device_params_,
  356. device_grads_,
  357. device_momentum_buffer_list,
  358. ), indices in grouped_tensors.values():
  359. device_params: list[Tensor] = cast(list[Tensor], device_params_)
  360. device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
  361. device_has_sparse_grad = has_sparse_grad and any(
  362. grad.is_sparse for grad in device_grads
  363. )
  364. if maximize:
  365. device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
  366. if weight_decay != 0:
  367. # Reuse the intermediate memory (device_grads) already allocated for maximize
  368. if maximize:
  369. torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
  370. else:
  371. device_grads = torch._foreach_add( # type: ignore[assignment]
  372. device_grads, device_params, alpha=weight_decay
  373. )
  374. if momentum != 0:
  375. bufs: list[Tensor] = []
  376. all_states_with_momentum_buffer = True
  377. for i in range(len(device_momentum_buffer_list)):
  378. if device_momentum_buffer_list[i] is None:
  379. all_states_with_momentum_buffer = False
  380. break
  381. else:
  382. bufs.append(cast(Tensor, device_momentum_buffer_list[i]))
  383. if all_states_with_momentum_buffer:
  384. torch._foreach_mul_(bufs, momentum)
  385. torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
  386. else:
  387. bufs = []
  388. for i in range(len(device_momentum_buffer_list)):
  389. if device_momentum_buffer_list[i] is None:
  390. buf = device_momentum_buffer_list[i] = momentum_buffer_list[
  391. indices[i]
  392. ] = device_grads[i].detach().clone()
  393. else:
  394. buf = cast(Tensor, device_momentum_buffer_list[i])
  395. buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
  396. bufs.append(buf)
  397. if nesterov:
  398. torch._foreach_add_(device_grads, bufs, alpha=momentum)
  399. else:
  400. device_grads = bufs
  401. if not device_has_sparse_grad:
  402. # handle internal item() call if lr is a tensor
  403. if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling():
  404. grads_x_lr = torch._foreach_mul(device_grads, -lr)
  405. torch._foreach_add_(device_params, grads_x_lr)
  406. else:
  407. torch._foreach_add_(device_params, device_grads, alpha=-lr)
  408. else:
  409. # foreach APIs don't support sparse
  410. for i in range(len(device_params)):
  411. device_params[i].add_(device_grads[i], alpha=-lr)
  412. def _fused_sgd(
  413. params: list[Tensor],
  414. grads: list[Tensor],
  415. momentum_buffer_list: list[Tensor | None],
  416. grad_scale: Tensor | None,
  417. found_inf: Tensor | None,
  418. *,
  419. weight_decay: float,
  420. momentum: float,
  421. lr: float,
  422. dampening: float,
  423. nesterov: bool,
  424. maximize: bool,
  425. has_sparse_grad: bool,
  426. ) -> None:
  427. if not params:
  428. return
  429. if has_sparse_grad:
  430. raise RuntimeError("`_fused_sgd` does not support sparse gradients")
  431. grad_scale_dict: DeviceDict = (
  432. {grad_scale.device: grad_scale} if grad_scale is not None else {}
  433. )
  434. found_inf_dict: DeviceDict = (
  435. {found_inf.device: found_inf} if found_inf is not None else {}
  436. )
  437. no_momentum_buffer = momentum == 0
  438. is_first_step = (
  439. all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
  440. )
  441. if is_first_step:
  442. for i, g in enumerate(grads):
  443. momentum_buffer_list[i] = torch.empty_like(g)
  444. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  445. [params, grads, momentum_buffer_list], # type: ignore[list-item]
  446. with_indices=False,
  447. )
  448. for (device, _), (
  449. (device_params_, device_grads_, device_momentum_buffer_list),
  450. _,
  451. ) in grouped_tensors.items():
  452. device_params: list[Tensor] = cast(list[Tensor], device_params_)
  453. device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
  454. device_grad_scale, device_found_inf = None, None
  455. if grad_scale is not None:
  456. device_grad_scale = grad_scale_dict.setdefault(
  457. device, grad_scale.to(device)
  458. )
  459. if found_inf_dict is not None and found_inf is not None:
  460. device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device))
  461. torch._fused_sgd_(
  462. device_params,
  463. device_grads,
  464. []
  465. if no_momentum_buffer
  466. else cast(list[Tensor], device_momentum_buffer_list),
  467. weight_decay=weight_decay,
  468. momentum=momentum,
  469. lr=lr,
  470. dampening=dampening,
  471. nesterov=nesterov,
  472. maximize=maximize,
  473. is_first_step=is_first_step,
  474. grad_scale=device_grad_scale,
  475. found_inf=device_found_inf,
  476. )