AmpKernels.h 871 B

123456789101112131415161718192021222324252627282930313233
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/native/DispatchStub.h>
  4. #include <ATen/core/ATen_fwd.h>
  5. namespace at {
  6. class Tensor;
  7. namespace native {
  8. using _amp_foreach_non_finite_check_and_unscale_cpu__fn = void (*)(
  9. TensorList,
  10. Tensor&,
  11. const Tensor&);
  12. using _amp_update_scale_cpu__fn = Tensor& (*)(
  13. Tensor&,
  14. Tensor&,
  15. const Tensor&,
  16. double,
  17. double,
  18. int64_t);
  19. DECLARE_DISPATCH(_amp_foreach_non_finite_check_and_unscale_cpu__fn, _amp_foreach_non_finite_check_and_unscale_cpu_stub)
  20. DECLARE_DISPATCH(_amp_update_scale_cpu__fn, _amp_update_scale_cpu_stub)
  21. } // namespace native
  22. } // namespace at
  23. #else
  24. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  25. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)