complex.h 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. /*
  3. pybind11/complex.h: Complex number support
  4. Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
  5. All rights reserved. Use of this source code is governed by a
  6. BSD-style license that can be found in the LICENSE file.
  7. */
  8. #pragma once
  9. #include "pybind11.h"
  10. #include <complex>
  11. /// glibc defines I as a macro which breaks things, e.g., boost template names
  12. #ifdef I
  13. # undef I
  14. #endif
  15. PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
  16. template <typename T>
  17. struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
  18. static constexpr const char c = format_descriptor<T>::c;
  19. static constexpr const char value[3] = {'Z', c, '\0'};
  20. static std::string format() { return std::string(value); }
  21. };
  22. #ifndef PYBIND11_CPP17
  23. template <typename T>
  24. constexpr const char
  25. format_descriptor<std::complex<T>,
  26. detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
  27. #endif
  28. PYBIND11_NAMESPACE_BEGIN(detail)
  29. template <typename T>
  30. struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
  31. static constexpr bool value = true;
  32. static constexpr int index = is_fmt_numeric<T>::index + 3;
  33. };
  34. template <typename T>
  35. class type_caster<std::complex<T>> {
  36. public:
  37. bool load(handle src, bool convert) {
  38. if (!src) {
  39. return false;
  40. }
  41. if (!convert && !PyComplex_Check(src.ptr())) {
  42. return false;
  43. }
  44. Py_complex result = PyComplex_AsCComplex(src.ptr());
  45. if (result.real == -1.0 && PyErr_Occurred()) {
  46. PyErr_Clear();
  47. return false;
  48. }
  49. value = std::complex<T>((T) result.real, (T) result.imag);
  50. return true;
  51. }
  52. static handle
  53. cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
  54. return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
  55. }
  56. PYBIND11_TYPE_CASTER(std::complex<T>, const_name("complex"));
  57. };
  58. PYBIND11_NAMESPACE_END(detail)
  59. PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
  60. #else
  61. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  62. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)