| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <c10/util/BFloat16.h>
- #include <c10/util/Half.h>
- C10_CLANG_DIAGNOSTIC_PUSH()
- #if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion")
- C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion")
- #endif
- namespace c10 {
- template <typename T>
- struct is_reduced_floating_point
- : std::integral_constant<
- bool,
- std::is_same_v<T, c10::Half> || std::is_same_v<T, c10::BFloat16>> {};
- template <typename T>
- constexpr bool is_reduced_floating_point_v =
- is_reduced_floating_point<T>::value;
- } // namespace c10
- namespace std {
- #if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED)
- using c10::is_reduced_floating_point;
- using c10::is_reduced_floating_point_v;
- #endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED)
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T acos(T a) {
- return std::acos(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T asin(T a) {
- return std::asin(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T atan(T a) {
- return std::atan(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T atanh(T a) {
- return std::atanh(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T erf(T a) {
- return std::erf(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T erfc(T a) {
- return std::erfc(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T exp(T a) {
- return std::exp(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T expm1(T a) {
- return std::expm1(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline bool isfinite(T a) {
- return std::isfinite(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T log(T a) {
- return std::log(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T log10(T a) {
- return std::log10(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T log1p(T a) {
- return std::log1p(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T log2(T a) {
- return std::log2(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T ceil(T a) {
- return std::ceil(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T cos(T a) {
- return std::cos(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T floor(T a) {
- return std::floor(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T nearbyint(T a) {
- return std::nearbyint(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T sin(T a) {
- return std::sin(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T tan(T a) {
- return std::tan(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T sinh(T a) {
- return std::sinh(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T cosh(T a) {
- return std::cosh(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T tanh(T a) {
- return std::tanh(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T trunc(T a) {
- return std::trunc(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T lgamma(T a) {
- return std::lgamma(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T sqrt(T a) {
- return std::sqrt(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T rsqrt(T a) {
- return 1.0 / std::sqrt(float(a));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T abs(T a) {
- return std::abs(float(a));
- }
- #if defined(_MSC_VER) && defined(__CUDACC__)
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T pow(T a, double b) {
- return std::pow(float(a), float(b));
- }
- #else
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T pow(T a, double b) {
- return std::pow(float(a), b);
- }
- #endif
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T pow(T a, T b) {
- return std::pow(float(a), float(b));
- }
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- inline T fmod(T a, T b) {
- return std::fmod(float(a), float(b));
- }
- /*
The following function is inspired by the implementation in `musl`
- Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT
- ----------------------------------------------------------------------
- Copyright © 2005-2020 Rich Felker, et al.
- Permission is hereby granted, free of charge, to any person obtaining
- a copy of this software and associated documentation files (the
- "Software"), to deal in the Software without restriction, including
- without limitation the rights to use, copy, modify, merge, publish,
- distribute, sublicense, and/or sell copies of the Software, and to
- permit persons to whom the Software is furnished to do so, subject to
- the following conditions:
- The above copyright notice and this permission notice shall be
- included in all copies or substantial portions of the Software.
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- ----------------------------------------------------------------------
- */
- template <
- typename T,
- typename std::enable_if_t<c10::is_reduced_floating_point_v<T>, int> = 0>
- C10_HOST_DEVICE inline T nextafter(T from, T to) {
- // Reference:
- // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c
- using int_repr_t = uint16_t;
- constexpr uint8_t bits = 16;
- union {
- T f;
- int_repr_t i;
- } ufrom = {from}, uto = {to};
- // get a mask to get the sign bit i.e. MSB
- int_repr_t sign_mask = int_repr_t{1} << (bits - 1);
- // short-circuit: if either is NaN, return NaN
- if (from != from || to != to) {
- return from + to;
- }
- // short-circuit: if they are exactly the same.
- if (ufrom.i == uto.i) {
- return from;
- }
- // mask the sign-bit to zero i.e. positive
- // equivalent to abs(x)
- int_repr_t abs_from = ufrom.i & ~sign_mask;
- int_repr_t abs_to = uto.i & ~sign_mask;
- if (abs_from == 0) {
- // if both are zero but with different sign,
- // preserve the sign of `to`.
- if (abs_to == 0) {
- return to;
- }
- // smallest subnormal with sign of `to`.
- ufrom.i = (uto.i & sign_mask) | int_repr_t{1};
- return ufrom.f;
- }
- // if abs(from) > abs(to) or sign(from) != sign(to)
- if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) {
- ufrom.i--;
- } else {
- ufrom.i++;
- }
- return ufrom.f;
- }
- } // namespace std
- C10_CLANG_DIAGNOSTIC_POP()
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|