complex.h 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <complex>
  4. #include <c10/macros/Macros.h>
  5. #include <c10/util/Half.h>
  6. #include <torch/headeronly/util/complex.h>
// std functions
//
// The implementations of these functions also follow the C++20 design.
  10. namespace std {
  11. template <typename T>
  12. constexpr T real(const c10::complex<T>& z) {
  13. return z.real();
  14. }
  15. template <typename T>
  16. constexpr T imag(const c10::complex<T>& z) {
  17. return z.imag();
  18. }
  19. template <typename T>
  20. C10_HOST_DEVICE T abs(const c10::complex<T>& z) {
  21. #if defined(__CUDACC__) || defined(__HIPCC__)
  22. return thrust::abs(static_cast<thrust::complex<T>>(z));
  23. #else
  24. return std::abs(static_cast<std::complex<T>>(z));
  25. #endif
  26. }
  27. #if defined(USE_ROCM)
  28. #define ROCm_Bug(x)
  29. #else
  30. #define ROCm_Bug(x) x
  31. #endif
  32. template <typename T>
  33. C10_HOST_DEVICE T arg(const c10::complex<T>& z) {
  34. return ROCm_Bug(std)::atan2(std::imag(z), std::real(z));
  35. }
  36. #undef ROCm_Bug
  37. template <typename T>
  38. constexpr T norm(const c10::complex<T>& z) {
  39. return z.real() * z.real() + z.imag() * z.imag();
  40. }
  41. // For std::conj, there are other versions of it:
  42. // constexpr std::complex<float> conj( float z );
  43. // template< class DoubleOrInteger >
  44. // constexpr std::complex<double> conj( DoubleOrInteger z );
  45. // constexpr std::complex<long double> conj( long double z );
  46. // These are not implemented
  47. // TODO(@zasdfgbnm): implement them as c10::conj
  48. template <typename T>
  49. constexpr c10::complex<T> conj(const c10::complex<T>& z) {
  50. return c10::complex<T>(z.real(), -z.imag());
  51. }
  52. // Thrust does not have complex --> complex version of thrust::proj,
  53. // so this function is not implemented at c10 right now.
  54. // TODO(@zasdfgbnm): implement it by ourselves
  55. // There is no c10 version of std::polar, because std::polar always
  56. // returns std::complex. Use c10::polar instead;
  57. } // namespace std
  58. #define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
  59. // math functions are included in a separate file
  60. #include <c10/util/complex_math.h> // IWYU pragma: keep
  61. // utilities for complex types
  62. #include <c10/util/complex_utils.h> // IWYU pragma: keep
  63. #undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H
  64. #else
  65. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  66. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)