FusedAdam.h 937 B

1234567891011121314151617181920212223242526272829303132
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #include <ATen/core/Tensor.h>
  3. #include <ATen/native/DispatchStub.h>
  4. namespace at::native {
  5. enum class ADAM_MODE : uint8_t { ORIGINAL = 0, ADAMW = 1 };
  6. using fused_adam_fn = void (*)(
  7. const at::Tensor& param,
  8. const at::Tensor& grad,
  9. const at::Tensor& exp_avg,
  10. const at::Tensor& exp_avg_sq,
  11. const at::Tensor& max_exp_avg_sq,
  12. const at::Tensor& state_step,
  13. const double lr,
  14. const double beta1,
  15. const double beta2,
  16. const double weight_decay,
  17. const double eps,
  18. const bool amsgrad,
  19. const bool maximize,
  20. const float* grad_scale_ptr,
  21. const ADAM_MODE);
  22. DECLARE_DISPATCH(fused_adam_fn, fused_adam_stub)
  23. } // namespace at::native
  24. #else
  25. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  26. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)