UnaryOps.h 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/native/DispatchStub.h>
  4. #include <ATen/Generator.h>
  5. #include <c10/core/Scalar.h>
  6. namespace at {
  7. class Tensor;
  8. class TensorBase;
  9. struct TensorIteratorBase;
  10. }
  11. namespace at::native {
  12. using unary_fn = void(*)(TensorIteratorBase&);
  13. using unary_fn_with_scalar = void(*)(TensorIteratorBase&, const Scalar& a);
  14. inline namespace CPU_CAPABILITY {
  15. void conj_kernel(TensorIteratorBase &iter);
  16. void neg_kernel(TensorIteratorBase &iter);
  17. void reciprocal_kernel(TensorIteratorBase &iter);
  18. void rsqrt_kernel(TensorIteratorBase& iter);
  19. void sqrt_kernel(TensorIteratorBase& iter);
  20. } // namespace CPU_CAPABILITY
  21. DECLARE_DISPATCH(unary_fn, abs_stub)
  22. DECLARE_DISPATCH(unary_fn, angle_stub)
  23. DECLARE_DISPATCH(unary_fn, conj_physical_stub)
  24. DECLARE_DISPATCH(unary_fn, acos_stub)
  25. DECLARE_DISPATCH(unary_fn, acosh_stub)
  26. DECLARE_DISPATCH(unary_fn, asinh_stub)
  27. DECLARE_DISPATCH(unary_fn, atanh_stub)
  28. DECLARE_DISPATCH(unary_fn, asin_stub)
  29. DECLARE_DISPATCH(unary_fn, atan_stub)
  30. DECLARE_DISPATCH(unary_fn, bitwise_not_stub)
  31. DECLARE_DISPATCH(unary_fn, logical_not_stub)
  32. DECLARE_DISPATCH(unary_fn, ceil_stub)
  33. DECLARE_DISPATCH(unary_fn, cos_stub)
  34. DECLARE_DISPATCH(unary_fn, cosh_stub)
  35. DECLARE_DISPATCH(unary_fn, digamma_stub)
  36. DECLARE_DISPATCH(unary_fn, special_entr_stub)
  37. DECLARE_DISPATCH(unary_fn, special_erfcx_stub)
  38. DECLARE_DISPATCH(unary_fn, erf_stub)
  39. DECLARE_DISPATCH(unary_fn, erfc_stub)
  40. DECLARE_DISPATCH(unary_fn, erfinv_stub)
  41. DECLARE_DISPATCH(unary_fn, exp_stub)
  42. DECLARE_DISPATCH(unary_fn, exp2_stub)
  43. DECLARE_DISPATCH(unary_fn, expm1_stub)
  44. DECLARE_DISPATCH(unary_fn, floor_stub)
  45. DECLARE_DISPATCH(unary_fn, frac_stub)
  46. DECLARE_DISPATCH(unary_fn, frexp_stub)
  47. DECLARE_DISPATCH(unary_fn, i0_stub)
  48. DECLARE_DISPATCH(unary_fn, special_i0e_stub)
  49. DECLARE_DISPATCH(unary_fn, special_i1_stub)
  50. DECLARE_DISPATCH(unary_fn, special_i1e_stub)
  51. DECLARE_DISPATCH(unary_fn, log_stub)
  52. DECLARE_DISPATCH(unary_fn, log10_stub)
  53. DECLARE_DISPATCH(unary_fn, log1p_stub)
  54. DECLARE_DISPATCH(unary_fn, log2_stub)
  55. DECLARE_DISPATCH(unary_fn, special_ndtri_stub)
  56. DECLARE_DISPATCH(unary_fn, special_log_ndtr_stub)
  57. DECLARE_DISPATCH(unary_fn, neg_stub)
  58. DECLARE_DISPATCH(unary_fn, reciprocal_stub)
  59. DECLARE_DISPATCH(unary_fn, round_stub)
  60. DECLARE_DISPATCH(unary_fn, rsqrt_stub)
  61. DECLARE_DISPATCH(unary_fn, sigmoid_stub)
  62. DECLARE_DISPATCH(unary_fn_with_scalar, logit_stub)
  63. DECLARE_DISPATCH(unary_fn, sign_stub)
  64. DECLARE_DISPATCH(unary_fn, signbit_stub)
  65. DECLARE_DISPATCH(unary_fn, sgn_stub)
  66. DECLARE_DISPATCH(unary_fn, sin_stub)
  67. DECLARE_DISPATCH(unary_fn, sinc_stub)
  68. DECLARE_DISPATCH(unary_fn, sinh_stub)
  69. DECLARE_DISPATCH(unary_fn, sqrt_stub)
  70. DECLARE_DISPATCH(unary_fn, tan_stub)
  71. DECLARE_DISPATCH(unary_fn, tanh_stub)
  72. DECLARE_DISPATCH(unary_fn, trigamma_stub)
  73. DECLARE_DISPATCH(unary_fn, trunc_stub)
  74. DECLARE_DISPATCH(unary_fn, lgamma_stub)
  75. DECLARE_DISPATCH(unary_fn, special_airy_ai_stub)
  76. DECLARE_DISPATCH(unary_fn, special_bessel_j0_stub)
  77. DECLARE_DISPATCH(unary_fn, special_bessel_j1_stub)
  78. DECLARE_DISPATCH(unary_fn, special_bessel_y0_stub)
  79. DECLARE_DISPATCH(unary_fn, special_bessel_y1_stub)
  80. DECLARE_DISPATCH(unary_fn, special_modified_bessel_i0_stub)
  81. DECLARE_DISPATCH(unary_fn, special_modified_bessel_i1_stub)
  82. DECLARE_DISPATCH(unary_fn, special_modified_bessel_k0_stub)
  83. DECLARE_DISPATCH(unary_fn, special_modified_bessel_k1_stub)
  84. DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k0_stub)
  85. DECLARE_DISPATCH(unary_fn, special_scaled_modified_bessel_k1_stub)
  86. DECLARE_DISPATCH(unary_fn, special_spherical_bessel_j0_stub)
  87. // NB: these are actually defined in Distribution
  88. DECLARE_DISPATCH(void(*)(const TensorBase&, const TensorBase&, std::optional<Generator>), bernoulli_tensor_stub)
  89. DECLARE_DISPATCH(void(*)(const TensorBase&, const double, std::optional<Generator>), bernoulli_scalar_stub)
  90. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), cauchy_stub)
  91. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), exponential_stub)
  92. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, std::optional<Generator>), geometric_stub)
  93. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), log_normal_stub)
  94. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const double, const double, std::optional<Generator>), uniform_stub)
  95. DECLARE_DISPATCH(void(*)(const TensorBase&, const double, const double, std::optional<Generator>), normal_stub)
  96. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const uint64_t, const int64_t, std::optional<Generator>), random_from_to_stub)
  97. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_full_64_bits_range_stub)
  98. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, std::optional<Generator>), random_stub)
  99. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t, const double), kaiser_window_stub)
  100. DECLARE_DISPATCH(void(*)(TensorIteratorBase&, const int64_t), polygamma_stub)
  101. DECLARE_DISPATCH(
  102. void (*)(Tensor&, const Tensor&, int64_t, std::optional<Generator>),
  103. multinomial_with_replacement_stub)
  104. DECLARE_DISPATCH(
  105. void (*)(
  106. TensorIteratorBase&,
  107. std::optional<double>,
  108. std::optional<double>,
  109. std::optional<double>),
  110. nan_to_num_stub)
  111. DECLARE_DISPATCH(void (*)(TensorIteratorBase&, int64_t), round_decimals_stub)
  112. // Missing unary functions
  113. // digamma
  114. // lgamma
  115. // erfinv
  116. // clone
  117. // contiguous
  118. // zero
  119. } // namespace at::native
  120. #else
  121. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  122. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)