asgd.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. # mypy: allow-untyped-defs
  2. from typing import cast
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import (
  6. _capturable_doc,
  7. _default_to_fused_or_foreach,
  8. _differentiable_doc,
  9. _disable_dynamo_if_unsupported,
  10. _foreach_doc,
  11. _get_capturable_supported_devices,
  12. _get_scalar_dtype,
  13. _get_value,
  14. _maximize_doc,
  15. _params_doc,
  16. _to_scalar,
  17. _use_grad_for_differentiable,
  18. _view_as_real,
  19. Optimizer,
  20. ParamsT,
  21. )
  22. __all__ = ["ASGD", "asgd"]
  23. class ASGD(Optimizer):
  24. def __init__(
  25. self,
  26. params: ParamsT,
  27. lr: float | Tensor = 1e-2,
  28. lambd: float = 1e-4,
  29. alpha: float = 0.75,
  30. t0: float = 1e6,
  31. weight_decay: float = 0,
  32. foreach: bool | None = None,
  33. maximize: bool = False,
  34. differentiable: bool = False,
  35. capturable: bool = False,
  36. ) -> None:
  37. if isinstance(lr, Tensor) and lr.numel() != 1:
  38. raise ValueError("Tensor lr must be 1-element")
  39. if not 0.0 <= lr:
  40. raise ValueError(f"Invalid learning rate: {lr}")
  41. if not 0.0 <= weight_decay:
  42. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  43. defaults = {
  44. "lr": lr,
  45. "lambd": lambd,
  46. "alpha": alpha,
  47. "t0": t0,
  48. "weight_decay": weight_decay,
  49. "foreach": foreach,
  50. "maximize": maximize,
  51. "differentiable": differentiable,
  52. "capturable": capturable,
  53. }
  54. super().__init__(params, defaults)
  55. def __setstate__(self, state):
  56. super().__setstate__(state)
  57. for group in self.param_groups:
  58. group.setdefault("foreach", None)
  59. group.setdefault("maximize", False)
  60. group.setdefault("differentiable", False)
  61. group.setdefault("capturable", False)
  62. for p in group["params"]:
  63. p_state = self.state.get(p, [])
  64. if len(p_state) != 0:
  65. if not torch.is_tensor(p_state["step"]):
  66. step_val = float(p_state["step"])
  67. p_state["step"] = torch.tensor(
  68. step_val, dtype=_get_scalar_dtype(), device=p.device
  69. )
  70. if not torch.is_tensor(p_state["eta"]):
  71. p_state["eta"] = torch.tensor(
  72. p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
  73. )
  74. if not torch.is_tensor(p_state["mu"]):
  75. p_state["mu"] = torch.tensor(
  76. p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
  77. )
  78. def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
  79. has_complex = False
  80. for p in group["params"]:
  81. if p.grad is not None:
  82. has_complex |= torch.is_complex(p)
  83. params_with_grad.append(p)
  84. if p.grad.is_sparse:
  85. raise RuntimeError("ASGD does not support sparse gradients")
  86. grads.append(p.grad)
  87. state = self.state[p]
  88. # State initialization
  89. if len(state) == 0:
  90. state["step"] = torch.zeros(
  91. (), device=p.device, dtype=_get_scalar_dtype()
  92. )
  93. state["eta"] = (
  94. torch.as_tensor(
  95. _to_scalar(group["lr"]),
  96. device=p.device,
  97. dtype=_get_scalar_dtype(),
  98. )
  99. .clone()
  100. .detach()
  101. )
  102. state["mu"] = torch.ones(
  103. (), device=p.device, dtype=_get_scalar_dtype()
  104. )
  105. state["ax"] = torch.zeros_like(
  106. p, memory_format=torch.preserve_format
  107. )
  108. mus.append(state["mu"])
  109. axs.append(state["ax"])
  110. etas.append(state["eta"])
  111. state_steps.append(state["step"])
  112. return has_complex
  113. @_use_grad_for_differentiable
  114. def step(self, closure=None):
  115. """Perform a single optimization step.
  116. Args:
  117. closure (Callable, optional): A closure that reevaluates the model
  118. and returns the loss.
  119. """
  120. self._accelerator_graph_capture_health_check()
  121. loss = None
  122. if closure is not None:
  123. with torch.enable_grad():
  124. loss = closure()
  125. for group in self.param_groups:
  126. params_with_grad: list[Tensor] = []
  127. grads: list[Tensor] = []
  128. mus: list[Tensor] = []
  129. axs: list[Tensor] = []
  130. etas: list[Tensor] = []
  131. state_steps: list[Tensor] = []
  132. has_complex = self._init_group(
  133. group, params_with_grad, grads, mus, axs, etas, state_steps
  134. )
  135. asgd(
  136. params_with_grad,
  137. grads,
  138. axs,
  139. mus,
  140. etas,
  141. state_steps,
  142. lambd=group["lambd"],
  143. lr=group["lr"],
  144. t0=group["t0"],
  145. alpha=group["alpha"],
  146. weight_decay=group["weight_decay"],
  147. foreach=group["foreach"],
  148. maximize=group["maximize"],
  149. differentiable=group["differentiable"],
  150. capturable=group["capturable"],
  151. has_complex=has_complex,
  152. )
  153. return loss
  154. ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.
  155. It has been proposed in `Acceleration of stochastic approximation by
  156. averaging`_.
  157. Args:
  158. {_params_doc}
  159. lr (float, Tensor, optional): learning rate (default: 1e-2)
  160. lambd (float, optional): decay term (default: 1e-4)
  161. alpha (float, optional): power for eta update (default: 0.75)
  162. t0 (float, optional): point at which to start averaging (default: 1e6)
  163. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  164. {_foreach_doc}
  165. {_maximize_doc}
  166. {_differentiable_doc}
  167. {_capturable_doc}
  168. .. _Acceleration of stochastic approximation by averaging:
  169. https://meyn.ece.ufl.edu/wp-content/uploads/sites/77/archive/spm_files/Courses/ECE555-2011/555media/poljud92.pdf
  170. """
def _single_tensor_asgd(
    params: list[Tensor],
    grads: list[Tensor],
    axs: list[Tensor],
    mus: list[Tensor],
    etas: list[Tensor],
    state_steps: list[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
) -> None:
    """Apply one ASGD step to each parameter with a plain Python loop.

    ``params``, ``grads``, ``axs``, ``mus``, ``etas`` and ``state_steps`` are
    parallel lists indexed together. For every parameter this updates, in
    place: the step counter, the parameter itself, its running average ``ax``,
    and the stored ``eta``/``mu`` used by the next step. Returns None.

    ``differentiable`` and ``has_complex`` are not read in this body; complex
    parameters are detected per tensor with ``torch.is_complex``.

    Raises:
        AssertionError: if ``capturable`` is True (outside torch.compile) and
            the parameter, mu, eta and step tensors are not all on the same
            supported device type.
    """
    if not torch.jit.is_scripting():
        lr = _to_scalar(lr)

    for i, param in enumerate(params):
        grad = grads[i]
        # Maximizing is implemented as descending on the negated gradient.
        grad = grad if not maximize else -grad
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch.compiler.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            if not (
                param.device.type
                == mu.device.type
                == eta.device.type
                == step_t.device.type
                and param.device.type in capturable_supported_devices
            ):
                raise AssertionError(
                    f"If capturable=True, params, mus, etas, and state_steps must be "
                    f"on supported devices: {capturable_supported_devices}."
                )

        # Complex tensors are updated through real views of their storage.
        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            param = torch.view_as_real(param)
            ax = torch.view_as_real(ax)

        # update step (in place on the state tensor held in state_steps)
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if capturable:
            # Stay in tensor ops (no .item()) so the update can be graph-captured.
            param.mul_(1 - lambd * eta)
            param.addcmul_(grad, eta, value=-1)  # update parameter
        else:
            eta_value = _get_value(eta)
            param.mul_(1 - lambd * eta_value)  # decay term
            param.add_(grad, alpha=-eta_value)  # update parameter

        # averaging: ax <- ax + mu * (param - ax). When mu == 1 this equals
        # param, so the non-capturable path copies it directly. mu.item()
        # reads the value back to the host, so that check is skipped when
        # capturable.
        if capturable or mu.item() != 1:
            ax.add_(param.sub(ax).mul_(mu))
        else:
            ax.copy_(param)

        # Recompute eta = lr / (1 + lambd * lr * step) ** alpha and
        # mu = 1 / max(1, step - t0) for the next step.
        if capturable:
            # pyrefly: ignore [unsupported-operation]
            eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
            mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
        else:
            step = _get_value(step_t)
            new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
            eta.copy_(new_eta)
            new_mu = torch.as_tensor(1 / max(1, step - t0))
            mu.copy_(new_mu)
  242. def _multi_tensor_asgd(
  243. params: list[Tensor],
  244. grads: list[Tensor],
  245. axs: list[Tensor],
  246. mus: list[Tensor],
  247. etas: list[Tensor],
  248. state_steps: list[Tensor],
  249. *,
  250. lambd: float,
  251. lr: float,
  252. t0: float,
  253. alpha: float,
  254. weight_decay: float,
  255. maximize: bool,
  256. differentiable: bool,
  257. capturable: bool,
  258. has_complex: bool,
  259. ) -> None:
  260. if len(params) == 0:
  261. return
  262. if differentiable:
  263. raise AssertionError("_foreach ops don't support autograd")
  264. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  265. if not torch.compiler.is_compiling() and capturable:
  266. capturable_supported_devices = _get_capturable_supported_devices(
  267. supports_xla=False
  268. )
  269. if not all(
  270. p.device.type == mu.device.type == eta.device.type == step.device.type
  271. and p.device.type in capturable_supported_devices
  272. for p, mu, eta, step in zip(params, mus, etas, state_steps, strict=True)
  273. ):
  274. raise AssertionError(
  275. f"If capturable=True, params, mus, etas, and state_steps must be on "
  276. f"supported devices: {capturable_supported_devices}."
  277. )
  278. lr = _to_scalar(lr)
  279. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  280. [params, grads, axs, mus, etas, state_steps] # type: ignore[list-item]
  281. )
  282. for (device, _), (
  283. (
  284. grouped_params_,
  285. grouped_grads_,
  286. grouped_axs_,
  287. grouped_mus_,
  288. grouped_etas_,
  289. grouped_state_steps_,
  290. ),
  291. _,
  292. ) in grouped_tensors.items():
  293. grouped_params = cast(list[Tensor], grouped_params_)
  294. grouped_grads = cast(list[Tensor], grouped_grads_)
  295. grouped_axs = cast(list[Tensor], grouped_axs_)
  296. grouped_mus = cast(list[Tensor], grouped_mus_)
  297. grouped_etas = cast(list[Tensor], grouped_etas_)
  298. grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
  299. if has_complex:
  300. _view_as_real(grouped_params, grouped_grads, grouped_axs)
  301. if maximize:
  302. grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment]
  303. # Update steps
  304. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  305. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  306. # wrapped it once now. The alpha is required to assure we go to the right overload.
  307. if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
  308. torch._foreach_add_(
  309. grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  310. )
  311. else:
  312. torch._foreach_add_(grouped_state_steps, 1)
  313. # intermediate = grad + param * lambd
  314. intermediate: tuple[Tensor, ...] | list[Tensor]
  315. if weight_decay != 0:
  316. if maximize:
  317. torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
  318. intermediate = grouped_grads
  319. else:
  320. intermediate = torch._foreach_add(
  321. grouped_grads, grouped_params, alpha=weight_decay
  322. )
  323. torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
  324. else:
  325. intermediate = torch._foreach_add(
  326. grouped_grads, grouped_params, alpha=lambd
  327. )
  328. # update param
  329. # param * (1 - lambd * eta) - eta * grad
  330. # => param - param * lambd * eta - eta * grad
  331. # => param - eta * intermediate
  332. torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
  333. del intermediate
  334. # update grouped_axs
  335. # averaging: ax = ax + mu * (param - ax)
  336. # Note (mlazos): We can't use lerp here since it requires weight to be float64
  337. # and our grouping code requires dtypes to match for all tensors in a group (and it should, since
  338. # we use the mus in other places)
  339. # all dtypes need to match, so we could introduce a cast in a loop
  340. # but since this only adds one additional kernel launch, this looks like the cleaner
  341. # and faster solution
  342. intermediate = torch._foreach_sub(grouped_params, grouped_axs)
  343. torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
  344. del intermediate
  345. new_etas: tuple[Tensor, ...] | list[Tensor]
  346. new_mus: tuple[Tensor, ...] | list[Tensor]
  347. if capturable:
  348. # update grouped_mus
  349. new_mus = torch._foreach_sub(grouped_state_steps, t0)
  350. torch._foreach_maximum_(new_mus, 1.0)
  351. torch._foreach_reciprocal_(new_mus)
  352. torch._foreach_copy_(grouped_mus, new_mus)
  353. del new_mus
  354. # update eta = lr / ((1 + lambd * lr * step)^alpha)
  355. new_etas = torch._foreach_mul(grouped_state_steps, lambd)
  356. torch._foreach_mul_(new_etas, lr)
  357. torch._foreach_add_(new_etas, 1)
  358. torch._foreach_pow_(new_etas, alpha)
  359. torch._foreach_reciprocal_(new_etas)
  360. torch._foreach_mul_(new_etas, lr)
  361. torch._foreach_copy_(grouped_etas, new_etas)
  362. else:
  363. new_etas = [
  364. torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
  365. for step in grouped_state_steps
  366. ]
  367. new_mus = [
  368. torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
  369. for step in grouped_state_steps
  370. ]
  371. torch._foreach_copy_(grouped_etas, new_etas)
  372. torch._foreach_copy_(grouped_mus, new_mus)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
    params: list[Tensor],
    grads: list[Tensor],
    axs: list[Tensor],
    mus: list[Tensor],
    etas: list[Tensor],
    state_steps: list[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: bool | None = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
) -> None:
    r"""Functional API that performs asgd algorithm computation.

    See :class:`~torch.optim.ASGD` for details.

    ``params``, ``grads``, ``axs``, ``mus``, ``etas`` and ``state_steps`` are
    parallel lists with one entry per parameter; the selected implementation
    updates them in place and nothing is returned.

    Dispatch: when ``foreach`` is None it is chosen by
    ``_default_to_fused_or_foreach``; a truthy ``foreach`` runs
    ``_multi_tensor_asgd``, otherwise ``_single_tensor_asgd`` runs (always the
    latter under TorchScript).

    Raises:
        RuntimeError: if ``foreach`` is requested under ``torch.jit.script``.
    """
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # The redundant ``not torch.jit.is_scripting()`` keeps the foreach branch
    # out of the TorchScript-compiled path.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )