MathConstants.h 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <c10/macros/Macros.h>
  4. #include <c10/util/BFloat16.h>
  5. #include <c10/util/Half.h>
  6. C10_CLANG_DIAGNOSTIC_PUSH()
  7. #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
  8. C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
  9. #endif
  10. namespace c10 {
  11. // TODO: Replace me with inline constexpr variable when C++17 becomes available
  12. namespace detail {
  13. template <typename T>
  14. C10_HOST_DEVICE inline constexpr T e() {
  15. return static_cast<T>(2.718281828459045235360287471352662);
  16. }
  17. template <typename T>
  18. C10_HOST_DEVICE inline constexpr T euler() {
  19. return static_cast<T>(0.577215664901532860606512090082402);
  20. }
  21. template <typename T>
  22. C10_HOST_DEVICE inline constexpr T frac_1_pi() {
  23. return static_cast<T>(0.318309886183790671537767526745028);
  24. }
  25. template <typename T>
  26. C10_HOST_DEVICE inline constexpr T frac_1_sqrt_pi() {
  27. return static_cast<T>(0.564189583547756286948079451560772);
  28. }
  29. template <typename T>
  30. C10_HOST_DEVICE inline constexpr T frac_sqrt_2() {
  31. return static_cast<T>(0.707106781186547524400844362104849);
  32. }
  33. template <typename T>
  34. C10_HOST_DEVICE inline constexpr T frac_sqrt_3() {
  35. return static_cast<T>(0.577350269189625764509148780501957);
  36. }
  37. template <typename T>
  38. C10_HOST_DEVICE inline constexpr T golden_ratio() {
  39. return static_cast<T>(1.618033988749894848204586834365638);
  40. }
  41. template <typename T>
  42. C10_HOST_DEVICE inline constexpr T ln_10() {
  43. return static_cast<T>(2.302585092994045684017991454684364);
  44. }
  45. template <typename T>
  46. C10_HOST_DEVICE inline constexpr T ln_2() {
  47. return static_cast<T>(0.693147180559945309417232121458176);
  48. }
  49. template <typename T>
  50. C10_HOST_DEVICE inline constexpr T log_10_e() {
  51. return static_cast<T>(0.434294481903251827651128918916605);
  52. }
  53. template <typename T>
  54. C10_HOST_DEVICE inline constexpr T log_2_e() {
  55. return static_cast<T>(1.442695040888963407359924681001892);
  56. }
  57. template <typename T>
  58. C10_HOST_DEVICE inline constexpr T pi() {
  59. return static_cast<T>(3.141592653589793238462643383279502);
  60. }
  61. template <typename T>
  62. C10_HOST_DEVICE inline constexpr T sqrt_2() {
  63. return static_cast<T>(1.414213562373095048801688724209698);
  64. }
  65. template <typename T>
  66. C10_HOST_DEVICE inline constexpr T sqrt_3() {
  67. return static_cast<T>(1.732050807568877293527446341505872);
  68. }
  69. template <>
  70. C10_HOST_DEVICE inline constexpr BFloat16 pi<BFloat16>() {
  71. // According to
  72. // https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Special_values
  73. // pi is encoded as 4049
  74. return BFloat16(0x4049, BFloat16::from_bits());
  75. }
  76. template <>
  77. C10_HOST_DEVICE inline constexpr Half pi<Half>() {
  78. return Half(0x4248, Half::from_bits());
  79. }
  80. } // namespace detail
  81. template <typename T>
  82. constexpr T e = c10::detail::e<T>();
  83. template <typename T>
  84. constexpr T euler = c10::detail::euler<T>();
  85. template <typename T>
  86. constexpr T frac_1_pi = c10::detail::frac_1_pi<T>();
  87. template <typename T>
  88. constexpr T frac_1_sqrt_pi = c10::detail::frac_1_sqrt_pi<T>();
  89. template <typename T>
  90. constexpr T frac_sqrt_2 = c10::detail::frac_sqrt_2<T>();
  91. template <typename T>
  92. constexpr T frac_sqrt_3 = c10::detail::frac_sqrt_3<T>();
  93. template <typename T>
  94. constexpr T golden_ratio = c10::detail::golden_ratio<T>();
  95. template <typename T>
  96. constexpr T ln_10 = c10::detail::ln_10<T>();
  97. template <typename T>
  98. constexpr T ln_2 = c10::detail::ln_2<T>();
  99. template <typename T>
  100. constexpr T log_10_e = c10::detail::log_10_e<T>();
  101. template <typename T>
  102. constexpr T log_2_e = c10::detail::log_2_e<T>();
  103. template <typename T>
  104. constexpr T pi = c10::detail::pi<T>();
  105. template <typename T>
  106. constexpr T sqrt_2 = c10::detail::sqrt_2<T>();
  107. template <typename T>
  108. constexpr T sqrt_3 = c10::detail::sqrt_3<T>();
  109. } // namespace c10
  110. C10_CLANG_DIAGNOSTIC_POP()
  111. #else
  112. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  113. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)