Pow.h 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/native/DispatchStub.h>
  4. namespace c10 {
  5. class Scalar;
  6. }
  7. namespace at {
  8. struct TensorIterator;
  9. struct TensorIteratorBase;
  10. namespace native {
  // HOST_DEVICE expands to `__host__ __device__` when __CUDACC__ or __HIPCC__ is
  // defined (i.e. the translation unit is being compiled as CUDA / HIP code),
  // and to nothing otherwise, so the helpers below can be shared between host
  // and device builds.
  #if !defined(__CUDACC__) && !defined(__HIPCC__)
  #define HOST_DEVICE
  #else
  #define HOST_DEVICE __host__ __device__
  #endif
  16. // integral power in pytorch allows for negative exponents, giving truncated integral results.
  17. // e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the
  18. // only non-zero result.
  19. template <class T,
  20. std::enable_if_t<std::is_integral_v<T>, T>* = nullptr>
  21. inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) {
  22. T result = 1;
  23. while (b) {
  24. if (b & 1) {
  25. result *= a;
  26. }
  27. b /= 2;
  28. a *= a;
  29. }
  30. return result;
  31. }
  32. template <class T,
  33. std::enable_if_t<std::is_integral_v<T> && !std::is_signed_v<T>, T>* = nullptr>
  34. inline HOST_DEVICE T powi(T a, T b) {
  35. return powi_impl(a, b);
  36. }
  37. template <class T,
  38. std::enable_if_t<std::is_integral_v<T> && std::is_signed_v<T>, T>* = nullptr>
  39. inline HOST_DEVICE T powi(T a, T b) {
  40. if ( b < 0 ) {
  41. if ( a == 1 ) {
  42. return 1;
  43. } else if ( a == -1 ) {
  44. auto negative = (-b) % static_cast<T>(2);
  45. return negative ? -1 : 1;
  46. } else {
  47. return 0;
  48. }
  49. }
  50. return powi_impl(a, b);
  51. }
  52. using pow_tensor_tensor_fn = void (*)(TensorIteratorBase&);
  53. using pow_tensor_scalar_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
  54. DECLARE_DISPATCH(pow_tensor_tensor_fn, pow_tensor_tensor_stub)
  55. DECLARE_DISPATCH(pow_tensor_scalar_fn, pow_tensor_scalar_stub)
  56. } // namespace native
  57. } // namespace at
  58. #else
  59. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  60. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)