adadelta.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. # mypy: allow-untyped-defs
  2. from typing import Any, cast
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import (
  6. _capturable_doc,
  7. _default_to_fused_or_foreach,
  8. _differentiable_doc,
  9. _disable_dynamo_if_unsupported,
  10. _foreach_doc,
  11. _get_capturable_supported_devices,
  12. _get_scalar_dtype,
  13. _maximize_doc,
  14. _params_doc,
  15. _to_scalar,
  16. _use_grad_for_differentiable,
  17. _view_as_real,
  18. Optimizer,
  19. ParamsT,
  20. )
# Public API of this module: the ``Adadelta`` optimizer class and the
# functional ``adadelta`` entry point that performs the actual update.
__all__ = ["Adadelta", "adadelta"]
  22. class Adadelta(Optimizer):
  23. def __init__(
  24. self,
  25. params: ParamsT,
  26. lr: float | Tensor = 1.0,
  27. rho: float = 0.9,
  28. eps: float = 1e-6,
  29. weight_decay: float = 0,
  30. foreach: bool | None = None,
  31. *,
  32. capturable: bool = False,
  33. maximize: bool = False,
  34. differentiable: bool = False,
  35. ) -> None:
  36. if isinstance(lr, Tensor) and lr.numel() != 1:
  37. raise ValueError("Tensor lr must be 1-element")
  38. if not 0.0 <= lr:
  39. raise ValueError(f"Invalid learning rate: {lr}")
  40. if not 0.0 <= rho <= 1.0:
  41. raise ValueError(f"Invalid rho value: {rho}")
  42. if not 0.0 <= eps:
  43. raise ValueError(f"Invalid epsilon value: {eps}")
  44. if not 0.0 <= weight_decay:
  45. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  46. defaults = {
  47. "lr": lr,
  48. "rho": rho,
  49. "eps": eps,
  50. "weight_decay": weight_decay,
  51. "maximize": maximize,
  52. "capturable": capturable,
  53. "foreach": foreach,
  54. "differentiable": differentiable,
  55. }
  56. super().__init__(params, defaults)
  57. def __setstate__(self, state):
  58. super().__setstate__(state)
  59. for group in self.param_groups:
  60. group.setdefault("foreach", None)
  61. group.setdefault("maximize", False)
  62. group.setdefault("differentiable", False)
  63. group.setdefault("capturable", False)
  64. for p in group["params"]:
  65. p_state = self.state.get(p, [])
  66. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  67. step_val = float(p_state["step"])
  68. p_state["step"] = (
  69. torch.tensor(
  70. step_val, dtype=_get_scalar_dtype(), device=p.device
  71. )
  72. if group["capturable"]
  73. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  74. )
  75. def _init_group(
  76. self,
  77. group: dict[str, Any],
  78. params_with_grad: list[Tensor],
  79. grads: list[Tensor],
  80. square_avgs: list[Tensor],
  81. acc_deltas: list[Tensor],
  82. state_steps: list[Tensor],
  83. ):
  84. has_complex = False
  85. p: Tensor
  86. for p in group["params"]:
  87. if p.grad is None:
  88. continue
  89. has_complex |= torch.is_complex(p)
  90. params_with_grad.append(p)
  91. if p.grad.is_sparse:
  92. raise RuntimeError("Adadelta does not support sparse gradients")
  93. grads.append(p.grad)
  94. state = self.state[p]
  95. # Lazy state initialization
  96. if len(state) == 0:
  97. state["step"] = (
  98. torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
  99. if group["capturable"]
  100. else torch.zeros((), dtype=_get_scalar_dtype())
  101. )
  102. state["square_avg"] = torch.zeros_like(
  103. p, memory_format=torch.preserve_format
  104. )
  105. state["acc_delta"] = torch.zeros_like(
  106. p, memory_format=torch.preserve_format
  107. )
  108. square_avgs.append(state["square_avg"])
  109. acc_deltas.append(state["acc_delta"])
  110. state_steps.append(state["step"])
  111. return has_complex
  112. @_use_grad_for_differentiable
  113. def step(self, closure=None):
  114. """Perform a single optimization step.
  115. Args:
  116. closure (Callable, optional): A closure that reevaluates the model
  117. and returns the loss.
  118. """
  119. self._accelerator_graph_capture_health_check()
  120. loss = None
  121. if closure is not None:
  122. with torch.enable_grad():
  123. loss = closure()
  124. for group in self.param_groups:
  125. params_with_grad: list[Tensor] = []
  126. grads: list[Tensor] = []
  127. square_avgs: list[Tensor] = []
  128. acc_deltas: list[Tensor] = []
  129. state_steps: list[Tensor] = []
  130. (
  131. lr,
  132. rho,
  133. eps,
  134. weight_decay,
  135. foreach,
  136. maximize,
  137. differentiable,
  138. capturable,
  139. ) = (
  140. group["lr"],
  141. group["rho"],
  142. group["eps"],
  143. group["weight_decay"],
  144. group["foreach"],
  145. group["maximize"],
  146. group["differentiable"],
  147. group["capturable"],
  148. )
  149. has_complex = self._init_group(
  150. group, params_with_grad, grads, square_avgs, acc_deltas, state_steps
  151. )
  152. adadelta(
  153. params_with_grad,
  154. grads,
  155. square_avgs,
  156. acc_deltas,
  157. state_steps,
  158. lr=lr,
  159. rho=rho,
  160. eps=eps,
  161. weight_decay=weight_decay,
  162. foreach=foreach,
  163. maximize=maximize,
  164. differentiable=differentiable,
  165. capturable=capturable,
  166. has_complex=has_complex,
  167. )
  168. return loss
  169. Adadelta.__doc__ = (
  170. r"""Implements Adadelta algorithm.
  171. .. math::
  172. \begin{aligned}
  173. &\rule{110mm}{0.4pt} \\
  174. &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
  175. \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)},
  176. \: \lambda \text{ (weight decay)} \\
  177. &\textbf{initialize} : v_0 \leftarrow 0 \: \text{ (square avg)},
  178. \: u_0 \leftarrow 0 \: \text{ (accumulate variables)} \\[-1.ex]
  179. &\rule{110mm}{0.4pt} \\
  180. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  181. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  182. &\hspace{5mm}if \: \lambda \neq 0 \\
  183. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  184. &\hspace{5mm} v_t \leftarrow v_{t-1} \rho + g^2_t (1 - \rho) \\
  185. &\hspace{5mm}\Delta x_t \leftarrow \frac{\sqrt{u_{t-1} +
  186. \epsilon }}{ \sqrt{v_t + \epsilon} }g_t \hspace{21mm} \\
  187. &\hspace{5mm} u_t \leftarrow u_{t-1} \rho +
  188. \Delta x^2_t (1 - \rho) \\
  189. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \Delta x_t \\
  190. &\rule{110mm}{0.4pt} \\[-1.ex]
  191. &\bf{return} \: \theta_t \\[-1.ex]
  192. &\rule{110mm}{0.4pt} \\[-1.ex]
  193. \end{aligned}
  194. For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_.
  195. """
  196. + rf"""
  197. Args:
  198. {_params_doc}
  199. lr (float, Tensor, optional): coefficient that scale delta before it is applied
  200. to the parameters (default: 1.0)
  201. rho (float, optional): coefficient used for computing a running average
  202. of squared gradients (default: 0.9). A higher value of `rho` will
  203. result in a slower average, which can be helpful for preventing
  204. oscillations in the learning process.
  205. eps (float, optional): term added to the denominator to improve
  206. numerical stability (default: 1e-6).
  207. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  208. {_foreach_doc}
  209. {_capturable_doc}
  210. {_maximize_doc}
  211. {_differentiable_doc}
  212. .. _ADADELTA\: An Adaptive Learning Rate Method:
  213. https://arxiv.org/abs/1212.5701
  214. """
  215. )
  216. def _single_tensor_adadelta(
  217. params: list[Tensor],
  218. grads: list[Tensor],
  219. square_avgs: list[Tensor],
  220. acc_deltas: list[Tensor],
  221. state_steps: list[Tensor],
  222. *,
  223. lr: float,
  224. rho: float,
  225. eps: float,
  226. weight_decay: float,
  227. maximize: bool,
  228. differentiable: bool,
  229. capturable: bool,
  230. has_complex: bool,
  231. ) -> None:
  232. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  233. if not torch.compiler.is_compiling() and capturable:
  234. capturable_supported_devices = _get_capturable_supported_devices(
  235. supports_xla=False
  236. )
  237. if not all(
  238. p.device.type == step.device.type
  239. and p.device.type in capturable_supported_devices
  240. for p, step in zip(params, state_steps, strict=True)
  241. ):
  242. raise AssertionError(
  243. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  244. )
  245. if not torch.jit.is_scripting():
  246. lr = _to_scalar(lr)
  247. for param, grad, square_avg, acc_delta, step in zip(
  248. params, grads, square_avgs, acc_deltas, state_steps, strict=True
  249. ):
  250. step += 1
  251. grad = grad if not maximize else -grad
  252. if weight_decay != 0:
  253. grad = grad.add(param, alpha=weight_decay)
  254. if torch.is_complex(param):
  255. square_avg = torch.view_as_real(square_avg)
  256. acc_delta = torch.view_as_real(acc_delta)
  257. grad = torch.view_as_real(grad)
  258. square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
  259. std = square_avg.add(eps).sqrt_()
  260. delta = acc_delta.add(eps).sqrt_()
  261. if differentiable:
  262. delta = delta.clone()
  263. delta.div_(std).mul_(grad)
  264. acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
  265. if torch.is_complex(param):
  266. delta = torch.view_as_complex(delta)
  267. param.add_(delta, alpha=-lr)
def _multi_tensor_adadelta(
    params: list[Tensor],
    grads: list[Tensor],
    square_avgs: list[Tensor],
    acc_deltas: list[Tensor],
    state_steps: list[Tensor],
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
) -> None:
    """Foreach (multi-tensor) implementation of the Adadelta update.

    The input lists are grouped with
    ``Optimizer._group_tensors_by_device_and_dtype`` and each group is updated
    with ``torch._foreach_*`` ops. ``params``, ``square_avgs``, ``acc_deltas``
    and ``state_steps`` are modified in place. The caller's ``grads`` tensors
    are only read: every op that combines them writes to a fresh list or to
    another argument.

    Raises:
        AssertionError: if ``differentiable`` is True, or if ``capturable`` is
            True outside of ``torch.compile`` and some param/step pair is not on
            the same, capturable-supported device type.
    """
    if differentiable:
        raise AssertionError("_foreach ops don't support autograd")

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch.compiler.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        if not all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps, strict=True)
        ):
            raise AssertionError(
                f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
            )

    # Nothing to update; also guarantees device_state_steps[0] exists below.
    if len(params) == 0:
        return

    lr = _to_scalar(lr)

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, square_avgs, acc_deltas, state_steps]  # type: ignore[list-item]
    )
    # Each value is (the five per-group tensor lists, <second element unused here>).
    for (
        device_params_,
        device_grads_,
        device_square_avgs_,
        device_acc_deltas_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(list[Tensor], device_params_)
        device_grads = cast(list[Tensor], device_grads_)
        device_square_avgs = cast(list[Tensor], device_square_avgs_)
        device_acc_deltas = cast(list[Tensor], device_acc_deltas_)
        device_state_steps = cast(list[Tensor], device_state_steps_)
        if has_complex:
            # Presumably swaps complex entries for real views in these lists
            # (see ``_view_as_real`` in .optimizer) so the real-valued foreach
            # math below applies; state_steps are left untouched.
            _view_as_real(
                device_params, device_grads, device_square_avgs, device_acc_deltas
            )

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        # _foreach_neg is out-of-place: device_grads now holds fresh tensors.
        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        if weight_decay != 0:
            # Reuse the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                # No fresh copy exists yet, so add out-of-place to avoid
                # mutating the caller's gradient tensors.
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        # square_avg <- rho * square_avg + (1 - rho) * grad^2
        torch._foreach_mul_(device_square_avgs, rho)
        torch._foreach_addcmul_(
            device_square_avgs, device_grads, device_grads, value=1 - rho
        )

        # std = sqrt(square_avg + eps)
        std = torch._foreach_add(device_square_avgs, eps)
        torch._foreach_sqrt_(std)

        # deltas = sqrt(acc_delta + eps) / std * grad
        deltas = torch._foreach_add(device_acc_deltas, eps)
        torch._foreach_sqrt_(deltas)
        torch._foreach_div_(deltas, std)
        torch._foreach_mul_(deltas, device_grads)

        # acc_delta <- rho * acc_delta + (1 - rho) * deltas^2
        torch._foreach_mul_(device_acc_deltas, rho)
        torch._foreach_addcmul_(device_acc_deltas, deltas, deltas, value=1 - rho)

        # If LR is a tensor, the else branch will internally call item()
        # which will cause silent incorrectness if we are capturing
        if capturable and isinstance(lr, torch.Tensor):
            torch._foreach_mul_(deltas, -lr)
            torch._foreach_add_(device_params, deltas)
        else:
            torch._foreach_add_(device_params, deltas, alpha=-lr)
  357. torch._foreach_add_(device_params, deltas)
  358. else:
  359. torch._foreach_add_(device_params, deltas, alpha=-lr)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta)
def adadelta(
    params: list[Tensor],
    grads: list[Tensor],
    square_avgs: list[Tensor],
    acc_deltas: list[Tensor],
    state_steps: list[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    capturable: bool = False,
    foreach: bool | None = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
) -> None:
    r"""Functional API that performs Adadelta algorithm computation.

    See :class:`~torch.optim.Adadelta` for details.

    ``params``, ``grads``, ``square_avgs``, ``acc_deltas`` and ``state_steps``
    are per-parameter lists aligned by index (``_single_tensor_adadelta`` zips
    them with ``strict=True``). The selected implementation updates the
    tensors in place; nothing is returned.

    When ``foreach`` is None it is resolved by ``_default_to_fused_or_foreach``.
    A truthy ``foreach`` dispatches to ``_multi_tensor_adadelta``; otherwise
    ``_single_tensor_adadelta`` is used.

    Raises:
        RuntimeError: if, outside of ``torch.compile``, some entry of
            ``state_steps`` is not a Tensor, or if ``foreach`` is requested
            while scripting with TorchScript.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch.compiler.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # We still respect when the user inputs False for foreach.
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # The seemingly redundant ``not torch.jit.is_scripting()`` is presumably
    # there so TorchScript can drop the foreach branch at compile time — confirm
    # before simplifying.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adadelta
    else:
        func = _single_tensor_adadelta

    func(
        params,
        grads,
        square_avgs,
        acc_deltas,
        state_steps,
        lr=lr,
        rho=rho,
        eps=eps,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )