functional_adagrad.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # mypy: allow-untyped-defs
  2. import torch
  3. import torch.optim._functional as F
  4. from torch import Tensor
  5. from torch.distributed.optim._deprecation_warning import (
  6. _scripted_functional_optimizer_deprecation_warning,
  7. )
# Intentionally empty: this module exports no public names. The optimizer
# class defined below is private (underscore-prefixed).
__all__: list[str] = []
# Define a TorchScript-compatible functional Adagrad optimizer, i.e. one that
# is used in a functional way: instead of reading `param.grad` when updating
# parameters, we explicitly let the user pass gradients to the `step`
# function. This separates the gradients from the parameters and allows a
# multithreaded trainer to update the parameters without data races caused by
# accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals and is
# not meant to be exposed to users.
  18. @torch.jit.script
  19. class _FunctionalAdagrad:
  20. def __init__(
  21. self,
  22. params: list[Tensor],
  23. lr: float = 1e-2,
  24. lr_decay: float = 0.0,
  25. weight_decay: float = 0.0,
  26. initial_accumulator_value: float = 0.0,
  27. warmup_lr_multiplier: float = 1.0,
  28. warmup_num_iters: float = 0.0,
  29. eps: float = 1e-10,
  30. coalesce_grad: bool = True,
  31. foreach: bool = False,
  32. fused: bool = False,
  33. maximize: bool = False,
  34. _allow_empty_param_list: bool = False,
  35. ):
  36. _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
  37. self.defaults = {
  38. "lr": lr,
  39. "lr_decay": lr_decay,
  40. "eps": eps,
  41. "weight_decay": weight_decay,
  42. "initial_accumulator_value": initial_accumulator_value,
  43. "warmup_lr_multiplier": warmup_lr_multiplier,
  44. "warmup_num_iters": warmup_num_iters,
  45. }
  46. self.coalesce_grad = coalesce_grad
  47. self.foreach = foreach
  48. self.fused = fused
  49. self.maximize = maximize
  50. self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
  51. if len(params) == 0 and not _allow_empty_param_list:
  52. raise ValueError("optimizer got an empty parameter list")
  53. # NOTE: we only have one param_group and don't allow user to add additional
  54. # param group as it's not a common use case.
  55. self.param_group = {"params": params}
  56. # TODO: no union or any types in TorchScript, make step a scalar tensor instead
  57. # This is also needed by if we want to share_memory on the step across processes
  58. for p in self.param_group["params"]:
  59. self.state[p] = {
  60. "sum": torch.full_like(p.data, initial_accumulator_value),
  61. "step": torch.tensor(0.0),
  62. }
  63. def step(self, gradients: list[Tensor | None]):
  64. params = self.param_group["params"]
  65. params_with_grad = []
  66. grads = []
  67. state_sums = []
  68. state_steps: list[Tensor] = []
  69. if len(params) != len(gradients):
  70. raise ValueError(
  71. "the gradients passed in does not equal to the size of the parameters!"
  72. + f"Params length: {len(params)}. "
  73. + f"Gradients length: {len(gradients)}"
  74. )
  75. has_sparse_grad, has_complex = False, False
  76. for param, gradient in zip(self.param_group["params"], gradients):
  77. if gradient is not None:
  78. has_sparse_grad |= gradient.is_sparse
  79. has_complex |= torch.is_complex(param)
  80. params_with_grad.append(param)
  81. grads.append(gradient)
  82. state = self.state[param]
  83. state_sums.append(state["sum"])
  84. state_steps.append(state["step"])
  85. with torch.no_grad():
  86. F.adagrad(
  87. params,
  88. grads,
  89. state_sums,
  90. state_steps,
  91. lr=self.defaults["lr"],
  92. weight_decay=self.defaults["weight_decay"],
  93. lr_decay=self.defaults["lr_decay"],
  94. eps=self.defaults["eps"],
  95. has_sparse_grad=has_sparse_grad,
  96. foreach=self.foreach,
  97. maximize=self.maximize,
  98. has_complex=has_complex,
  99. fused=self.fused,
  100. grad_scale=None,
  101. found_inf=None,
  102. )