adamw.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # mypy: allow-untyped-defs
  2. from torch import Tensor
  3. from .adam import Adam, adam
  4. from .optimizer import (
  5. _capturable_doc,
  6. _differentiable_doc,
  7. _foreach_doc,
  8. _fused_doc,
  9. _maximize_doc,
  10. _params_doc,
  11. ParamsT,
  12. )
# Names exported by ``from <this module> import *``: the optimizer class and
# its functional counterpart, both defined below.
__all__ = ["AdamW", "adamw"]
  14. class AdamW(Adam):
  15. def __init__(
  16. self,
  17. params: ParamsT,
  18. lr: float | Tensor = 1e-3,
  19. betas: tuple[float | Tensor, float | Tensor] = (0.9, 0.999),
  20. eps: float = 1e-8,
  21. weight_decay: float = 1e-2,
  22. amsgrad: bool = False,
  23. *,
  24. maximize: bool = False,
  25. foreach: bool | None = None,
  26. capturable: bool = False,
  27. differentiable: bool = False,
  28. fused: bool | None = None,
  29. ) -> None:
  30. super().__init__(
  31. params,
  32. lr,
  33. betas,
  34. eps,
  35. weight_decay,
  36. amsgrad,
  37. foreach=foreach,
  38. maximize=maximize,
  39. capturable=capturable,
  40. differentiable=differentiable,
  41. fused=fused,
  42. decoupled_weight_decay=True,
  43. )
  44. # Preserve decoupled_weight_decay from AdamW for backwards compatibility. The following
  45. # guarantees that decoupled_weight_decay will always be True for loading any state into
  46. # AdamW
  47. def __setstate__(self, state):
  48. super().__setstate__(state)
  49. for group in self.param_groups:
  50. group["decoupled_weight_decay"] = True
  51. AdamW.__doc__ = (
  52. r"""Implements AdamW algorithm, where weight decay does not accumulate in the momentum nor variance.
  53. .. math::
  54. \begin{aligned}
  55. &\rule{110mm}{0.4pt} \\
  56. &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2
  57. \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)},
  58. \: \epsilon \text{ (epsilon)} \\
  59. &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad},
  60. \: \textit{maximize} \\
  61. &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0
  62. \text{ ( second moment)}, \: v_0^{max}\leftarrow 0 \\[-1.ex]
  63. &\rule{110mm}{0.4pt} \\
  64. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  65. &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
  66. &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
  67. &\hspace{5mm}\textbf{else} \\
  68. &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  69. &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\
  70. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  71. &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  72. &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
  73. &\hspace{5mm}\textbf{if} \: amsgrad \\
  74. &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\
  75. &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\
  76. &\hspace{5mm}\textbf{else} \\
  77. &\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
  78. &\hspace{5mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/
  79. \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
  80. &\rule{110mm}{0.4pt} \\[-1.ex]
  81. &\bf{return} \: \theta_t \\[-1.ex]
  82. &\rule{110mm}{0.4pt} \\[-1.ex]
  83. \end{aligned}
  84. For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_.
  85. """
  86. + rf"""
  87. Args:
  88. {_params_doc}
  89. lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
  90. is not yet supported for all our implementations. Please use a float
  91. LR if you are not also specifying fused=True or capturable=True.
  92. betas (tuple[Union[float, Tensor], Union[float, Tensor]], optional):
  93. coefficients used for computing running averages of gradient and
  94. its square. If a tensor is provided, must be 1-element. (default: (0.9, 0.999))
  95. eps (float, optional): term added to the denominator to improve
  96. numerical stability (default: 1e-8)
  97. weight_decay (float, optional): weight decay coefficient (default: 1e-2)
  98. amsgrad (bool, optional): whether to use the AMSGrad variant of this
  99. algorithm from the paper `On the Convergence of Adam and Beyond`_
  100. (default: False)
  101. {_maximize_doc}
  102. {_foreach_doc}
  103. {_capturable_doc}
  104. {_differentiable_doc}
  105. {_fused_doc}
  106. .. Note::
  107. A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
  108. .. _Decoupled Weight Decay Regularization:
  109. https://arxiv.org/abs/1711.05101
  110. .. _On the Convergence of Adam and Beyond:
  111. https://openreview.net/forum?id=ryQu7f-RZ
  112. """
  113. )
  114. # @_disable_dynamo_if_unsupported logic occurs in the decorator that's applied to F.adam
  115. def adamw(
  116. params: list[Tensor],
  117. grads: list[Tensor],
  118. exp_avgs: list[Tensor],
  119. exp_avg_sqs: list[Tensor],
  120. max_exp_avg_sqs: list[Tensor],
  121. state_steps: list[Tensor],
  122. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  123. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  124. foreach: bool | None = None,
  125. capturable: bool = False,
  126. differentiable: bool = False,
  127. fused: bool | None = None,
  128. grad_scale: Tensor | None = None,
  129. found_inf: Tensor | None = None,
  130. has_complex: bool = False,
  131. *,
  132. amsgrad: bool,
  133. beta1: float | Tensor,
  134. beta2: float | Tensor,
  135. lr: float | Tensor,
  136. weight_decay: float,
  137. eps: float,
  138. maximize: bool,
  139. ) -> None:
  140. r"""Functional API that performs AdamW algorithm computation.
  141. See :class:`~torch.optim.AdamW` for details.
  142. """
  143. adam(
  144. params,
  145. grads,
  146. exp_avgs,
  147. exp_avg_sqs,
  148. max_exp_avg_sqs,
  149. state_steps,
  150. foreach=foreach,
  151. capturable=capturable,
  152. differentiable=differentiable,
  153. fused=fused,
  154. grad_scale=grad_scale,
  155. found_inf=found_inf,
  156. has_complex=has_complex,
  157. amsgrad=amsgrad,
  158. beta1=beta1,
  159. beta2=beta2,
  160. lr=lr,
  161. weight_decay=weight_decay,
  162. eps=eps,
  163. maximize=maximize,
  164. decoupled_weight_decay=True,
  165. )