// BinaryOps.h (6.2 KB)
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/TensorBase.h>
  4. #include <ATen/native/DispatchStub.h>
  5. #include <c10/core/Scalar.h>
  6. #include <c10/util/TypeSafeSignMath.h>
  7. #include <ATen/native/TensorIterator.h>
  // Forward declarations of the iterator types that the kernel signatures in
  // at::native (below) take by reference.
  namespace at {
  struct TensorIteratorBase;
  struct TensorIterator;
  } // namespace at
  12. namespace at::native {
  13. inline void alpha_check(const ScalarType dtype, const Scalar& alpha) {
  14. TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
  15. "Boolean alpha only supported for Boolean results.");
  16. TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
  17. || alpha.isIntegral(true),
  18. "For integral input tensors, argument alpha must not be a floating point number.");
  19. TORCH_CHECK(isComplexType(dtype) || !alpha.isComplex(),
  20. "For non-complex input tensors, argument alpha must not be a complex number.")
  21. }
  22. // Basic checking for all sub functions.
  23. inline void sub_check(const TensorBase& self, const TensorBase& other) {
  24. TORCH_CHECK(self.scalar_type() != kBool || other.scalar_type() != kBool,
  25. "Subtraction, the `-` operator, with two bool tensors is not supported. "
  26. "Use the `^` or `logical_xor()` operator instead.")
  27. TORCH_CHECK(self.scalar_type() != kBool && other.scalar_type() != kBool,
  28. "Subtraction, the `-` operator, with a bool tensor is not supported. "
  29. "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
  30. }
  31. inline void sub_check(const TensorBase& self, const Scalar& scalar) {
  32. TORCH_CHECK(self.scalar_type() != kBool || !scalar.isBoolean(),
  33. "Subtraction, the `-` operator, with two bool tensors is not supported. "
  34. "Use the `^` or `logical_xor()` operator instead.")
  35. TORCH_CHECK(self.scalar_type() != kBool && !scalar.isBoolean(),
  36. "Subtraction, the `-` operator, with a bool tensor is not supported. "
  37. "If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.");
  38. }
  39. using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
  40. using structured_binary_fn_double = void(*)(TensorIteratorBase&, double);
  41. using structured_binary_fn = void(*)(TensorIteratorBase&);
  42. using binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
  43. using binary_fn_double = void(*)(TensorIterator&, double);
  44. using binary_fn = void(*)(TensorIterator&);
  45. using binary_clamp_fn_alpha =
  46. void(*)(TensorIterator&, const Scalar& alpha, const Scalar& min_val, const Scalar& max_val);
  47. using ldexp_fn = void(*)(TensorIteratorBase&);
  48. // NB: codegenned
  49. DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub)
  50. DECLARE_DISPATCH(binary_clamp_fn_alpha, add_clamp_stub)
  51. DECLARE_DISPATCH(structured_binary_fn_alpha, sub_stub)
  52. DECLARE_DISPATCH(structured_binary_fn, mul_stub)
  53. DECLARE_DISPATCH(structured_binary_fn, div_true_stub)
  54. DECLARE_DISPATCH(structured_binary_fn, div_floor_stub)
  55. DECLARE_DISPATCH(structured_binary_fn, div_trunc_stub)
  56. DECLARE_DISPATCH(structured_binary_fn, atan2_stub)
  57. DECLARE_DISPATCH(structured_binary_fn, remainder_stub)
  58. DECLARE_DISPATCH(structured_binary_fn, bitwise_and_stub)
  59. DECLARE_DISPATCH(structured_binary_fn, bitwise_or_stub)
  60. DECLARE_DISPATCH(structured_binary_fn, bitwise_xor_stub)
  61. DECLARE_DISPATCH(structured_binary_fn, lshift_stub)
  62. DECLARE_DISPATCH(structured_binary_fn, rshift_stub)
  63. DECLARE_DISPATCH(binary_fn, logical_xor_stub)
  64. DECLARE_DISPATCH(binary_fn, logical_and_stub)
  65. DECLARE_DISPATCH(binary_fn, logical_or_stub)
  66. DECLARE_DISPATCH(structured_binary_fn, lt_stub)
  67. DECLARE_DISPATCH(structured_binary_fn, le_stub)
  68. DECLARE_DISPATCH(structured_binary_fn, gt_stub)
  69. DECLARE_DISPATCH(structured_binary_fn, ge_stub)
  70. DECLARE_DISPATCH(structured_binary_fn, eq_stub)
  71. DECLARE_DISPATCH(structured_binary_fn, ne_stub)
  72. DECLARE_DISPATCH(binary_fn, max_elementwise_stub)
  73. DECLARE_DISPATCH(binary_fn, min_elementwise_stub)
  74. DECLARE_DISPATCH(structured_binary_fn, maximum_stub)
  75. DECLARE_DISPATCH(structured_binary_fn, minimum_stub)
  76. DECLARE_DISPATCH(structured_binary_fn, fmax_stub)
  77. DECLARE_DISPATCH(structured_binary_fn, fmin_stub)
  78. DECLARE_DISPATCH(structured_binary_fn_double, smooth_l1_stub)
  79. DECLARE_DISPATCH(binary_fn_double, huber_stub)
  80. DECLARE_DISPATCH(structured_binary_fn, sigmoid_backward_stub)
  81. DECLARE_DISPATCH(binary_fn_alpha, logit_backward_stub)
  82. DECLARE_DISPATCH(structured_binary_fn, tanh_backward_stub)
  83. DECLARE_DISPATCH(structured_binary_fn, mse_stub)
  84. DECLARE_DISPATCH(structured_binary_fn, fmod_stub)
  85. DECLARE_DISPATCH(structured_binary_fn, logaddexp_stub)
  86. DECLARE_DISPATCH(structured_binary_fn, logaddexp2_stub)
  87. DECLARE_DISPATCH(structured_binary_fn, gcd_stub)
  88. DECLARE_DISPATCH(structured_binary_fn, lcm_stub)
  89. DECLARE_DISPATCH(structured_binary_fn, hypot_stub)
  90. DECLARE_DISPATCH(structured_binary_fn, igamma_stub)
  91. DECLARE_DISPATCH(structured_binary_fn, igammac_stub)
  92. DECLARE_DISPATCH(structured_binary_fn, nextafter_stub)
  93. DECLARE_DISPATCH(structured_binary_fn, heaviside_stub)
  94. DECLARE_DISPATCH(structured_binary_fn, copysign_stub)
  95. DECLARE_DISPATCH(structured_binary_fn, xlogy_stub)
  96. DECLARE_DISPATCH(structured_binary_fn, xlog1py_stub)
  97. DECLARE_DISPATCH(structured_binary_fn, zeta_stub)
  98. DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_t_stub)
  99. DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_u_stub)
  100. DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_v_stub)
  101. DECLARE_DISPATCH(structured_binary_fn, chebyshev_polynomial_w_stub)
  102. DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_h_stub)
  103. DECLARE_DISPATCH(structured_binary_fn, hermite_polynomial_he_stub)
  104. DECLARE_DISPATCH(structured_binary_fn, laguerre_polynomial_l_stub)
  105. DECLARE_DISPATCH(structured_binary_fn, legendre_polynomial_p_stub)
  106. DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_t_stub)
  107. DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_u_stub)
  108. DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_v_stub)
  109. DECLARE_DISPATCH(structured_binary_fn, shifted_chebyshev_polynomial_w_stub)
  110. DECLARE_DISPATCH(ldexp_fn, ldexp_stub)
  111. } // namespace at::native
  112. #else
  113. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  114. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)