functional.h 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. /*
  3. pybind11/functional.h: std::function<> support
  4. Copyright (c) 2016 Wenzel Jakob <wenzel.jakob@epfl.ch>
  5. All rights reserved. Use of this source code is governed by a
  6. BSD-style license that can be found in the LICENSE file.
  7. */
  8. #pragma once
  9. #include "pybind11.h"
  10. #include <functional>
  11. PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
  12. PYBIND11_NAMESPACE_BEGIN(detail)
  13. PYBIND11_NAMESPACE_BEGIN(type_caster_std_function_specializations)
  14. // ensure GIL is held during functor destruction
  15. struct func_handle {
  16. function f;
  17. #if !(defined(_MSC_VER) && _MSC_VER == 1916 && defined(PYBIND11_CPP17))
  18. // This triggers a syntax error under very special conditions (very weird indeed).
  19. explicit
  20. #endif
  21. func_handle(function &&f_) noexcept
  22. : f(std::move(f_)) {
  23. }
  24. func_handle(const func_handle &f_) { operator=(f_); }
  25. func_handle &operator=(const func_handle &f_) {
  26. gil_scoped_acquire acq;
  27. f = f_.f;
  28. return *this;
  29. }
  30. ~func_handle() {
  31. gil_scoped_acquire acq;
  32. function kill_f(std::move(f));
  33. }
  34. };
  35. // to emulate 'move initialization capture' in C++11
  36. struct func_wrapper_base {
  37. func_handle hfunc;
  38. explicit func_wrapper_base(func_handle &&hf) noexcept : hfunc(hf) {}
  39. };
  40. template <typename Return, typename... Args>
  41. struct func_wrapper : func_wrapper_base {
  42. using func_wrapper_base::func_wrapper_base;
  43. Return operator()(Args... args) const { // NOLINT(performance-unnecessary-value-param)
  44. gil_scoped_acquire acq;
  45. // casts the returned object as a rvalue to the return type
  46. return hfunc.f(std::forward<Args>(args)...).template cast<Return>();
  47. }
  48. };
  49. PYBIND11_NAMESPACE_END(type_caster_std_function_specializations)
  50. template <typename Return, typename... Args>
  51. struct type_caster<std::function<Return(Args...)>> {
  52. using type = std::function<Return(Args...)>;
  53. using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
  54. using function_type = Return (*)(Args...);
  55. public:
  56. bool load(handle src, bool convert) {
  57. if (src.is_none()) {
  58. // Defer accepting None to other overloads (if we aren't in convert mode):
  59. if (!convert) {
  60. return false;
  61. }
  62. return true;
  63. }
  64. if (!isinstance<function>(src)) {
  65. return false;
  66. }
  67. auto func = reinterpret_borrow<function>(src);
  68. /*
  69. When passing a C++ function as an argument to another C++
  70. function via Python, every function call would normally involve
  71. a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
  72. Here, we try to at least detect the case where the function is
  73. stateless (i.e. function pointer or lambda function without
  74. captured variables), in which case the roundtrip can be avoided.
  75. */
  76. if (auto cfunc = func.cpp_function()) {
  77. auto *cfunc_self = PyCFunction_GET_SELF(cfunc.ptr());
  78. if (cfunc_self == nullptr) {
  79. PyErr_Clear();
  80. } else {
  81. function_record *rec = function_record_ptr_from_PyObject(cfunc_self);
  82. while (rec != nullptr) {
  83. if (rec->is_stateless
  84. && same_type(typeid(function_type),
  85. *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
  86. struct capture {
  87. function_type f;
  88. static capture *from_data(void **data) {
  89. return PYBIND11_STD_LAUNDER(reinterpret_cast<capture *>(data));
  90. }
  91. };
  92. PYBIND11_ENSURE_PRECONDITION_FOR_FUNCTIONAL_H_PERFORMANCE_OPTIMIZATIONS(
  93. std::is_standard_layout<capture>::value);
  94. value = capture::from_data(rec->data)->f;
  95. return true;
  96. }
  97. rec = rec->next;
  98. }
  99. }
  100. // PYPY segfaults here when passing builtin function like sum.
  101. // Raising an fail exception here works to prevent the segfault, but only on gcc.
  102. // See PR #1413 for full details
  103. }
  104. value = type_caster_std_function_specializations::func_wrapper<Return, Args...>(
  105. type_caster_std_function_specializations::func_handle(std::move(func)));
  106. return true;
  107. }
  108. template <typename Func>
  109. static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
  110. if (!f_) {
  111. return none().release();
  112. }
  113. auto result = f_.template target<function_type>();
  114. if (result) {
  115. return cpp_function(*result, policy).release();
  116. }
  117. return cpp_function(std::forward<Func>(f_), policy).release();
  118. }
  119. PYBIND11_TYPE_CASTER(
  120. type,
  121. const_name("collections.abc.Callable[[")
  122. + ::pybind11::detail::concat(::pybind11::detail::arg_descr(make_caster<Args>::name)...)
  123. + const_name("], ") + ::pybind11::detail::return_descr(make_caster<retval_type>::name)
  124. + const_name("]"));
  125. };
  126. PYBIND11_NAMESPACE_END(detail)
  127. PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
  128. #else
  129. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  130. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)