_functional.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # mypy: allow-untyped-defs
  2. r"""Functional interface."""
  3. import math
  4. from torch import Tensor
  5. from .adadelta import adadelta # type: ignore[attr-defined] # noqa: F401
  6. from .adagrad import _make_sparse, adagrad # type: ignore[attr-defined] # noqa: F401
  7. from .adam import adam # type: ignore[attr-defined] # noqa: F401
  8. from .adamax import adamax # type: ignore[attr-defined] # noqa: F401
  9. from .adamw import adamw # type: ignore[attr-defined] # noqa: F401
  10. from .asgd import asgd # type: ignore[attr-defined] # noqa: F401
  11. from .nadam import nadam # type: ignore[attr-defined] # noqa: F401
  12. from .radam import radam # type: ignore[attr-defined] # noqa: F401
  13. from .rmsprop import rmsprop # type: ignore[attr-defined] # noqa: F401
  14. from .rprop import rprop # type: ignore[attr-defined] # noqa: F401
  15. from .sgd import sgd # type: ignore[attr-defined] # noqa: F401
  16. # TODO: use foreach API in optim._functional to do all the computation
  17. def sparse_adam(
  18. params: list[Tensor],
  19. grads: list[Tensor],
  20. exp_avgs: list[Tensor],
  21. exp_avg_sqs: list[Tensor],
  22. state_steps: list[int],
  23. *,
  24. eps: float,
  25. beta1: float,
  26. beta2: float,
  27. lr: float,
  28. maximize: bool,
  29. ) -> None:
  30. r"""Functional API that performs Sparse Adam algorithm computation.
  31. See :class:`~torch.optim.SparseAdam` for details.
  32. """
  33. for i, param in enumerate(params):
  34. grad = grads[i]
  35. grad = grad if not maximize else -grad
  36. grad = grad.coalesce() # the update is non-linear so indices must be unique
  37. grad_indices = grad._indices()
  38. grad_values = grad._values()
  39. if grad_values.numel() == 0:
  40. # Skip update for empty grad
  41. continue
  42. size = grad.size()
  43. exp_avg = exp_avgs[i]
  44. exp_avg_sq = exp_avg_sqs[i]
  45. step = state_steps[i]
  46. def make_sparse(values):
  47. constructor = grad.new
  48. if grad_indices.dim() == 0 or values.dim() == 0:
  49. return constructor().resize_as_(grad)
  50. return constructor(grad_indices, values, size)
  51. # Decay the first and second moment running average coefficient
  52. # old <- b * old + (1 - b) * new
  53. # <==> old += (1 - b) * (new - old)
  54. old_exp_avg_values = exp_avg.sparse_mask(grad)._values()
  55. exp_avg_update_values = grad_values.sub(old_exp_avg_values).mul_(1 - beta1)
  56. exp_avg.add_(make_sparse(exp_avg_update_values))
  57. old_exp_avg_sq_values = exp_avg_sq.sparse_mask(grad)._values()
  58. exp_avg_sq_update_values = (
  59. grad_values.pow(2).sub_(old_exp_avg_sq_values).mul_(1 - beta2)
  60. )
  61. exp_avg_sq.add_(make_sparse(exp_avg_sq_update_values))
  62. # Dense addition again is intended, avoiding another sparse_mask
  63. numer = exp_avg_update_values.add_(old_exp_avg_values)
  64. exp_avg_sq_update_values.add_(old_exp_avg_sq_values)
  65. denom = exp_avg_sq_update_values.sqrt_().add_(eps)
  66. del exp_avg_update_values, exp_avg_sq_update_values
  67. bias_correction1 = 1 - beta1**step
  68. bias_correction2 = 1 - beta2**step
  69. step_size = lr * math.sqrt(bias_correction2) / bias_correction1
  70. param.add_(make_sparse(-step_size * numer.div_(denom)))