FunctionTraits.h 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <cstddef>
  4. #include <tuple>
  5. // Modified from https://stackoverflow.com/questions/7943525/is-it-possible-to-figure-out-the-parameter-type-and-return-type-of-a-lambda
  // function_traits<F> exposes, for a callable type F:
  //   - arity:        number of parameters (an unnamed-enum constant),
  //   - result_type:  the return type,
  //   - ArgsTuple:    std::tuple<Args...> of the parameter types,
  //   - arg<i>::type: the i-th parameter type.
  // F may be a function type, a function pointer/reference, a pointer to a
  // const or non-const member function (the implicit object parameter is NOT
  // counted; see invoke_traits for that), a pointer to a data member whose type
  // is itself callable, or a class type with a single non-overloaded,
  // non-template operator() (e.g. a non-generic lambda).

  // Fallback, anything with an operator()
  template <typename T>
  struct function_traits : public function_traits<decltype(&T::operator())> {
  };

  // Pointers to class members that are themselves functors.
  // For example, in the following code:
  //   template <typename func_t>
  //   struct S {
  //     func_t f;
  //   };
  //   template <typename func_t>
  //   S<func_t> make_s(func_t f) {
  //     return S<func_t>{f};
  //   }
  //
  //   auto s = make_s([] (int, float) -> double { /* ... */ });
  //
  //   function_traits<decltype(&decltype(s)::f)> traits;
  //
  // Non-const member functions also land here: for
  // ReturnType(ClassType::*)(Args...), T deduces to ReturnType(Args...), so the
  // class is dropped and only Args... are counted.
  template <typename ClassType, typename T>
  struct function_traits<T ClassType::*> : public function_traits<T> {
  };

  // Const class member functions (the implicit object parameter is not counted).
  template <typename ClassType, typename ReturnType, typename... Args>
  struct function_traits<ReturnType(ClassType::*)(Args...) const> : public function_traits<ReturnType(Args...)> {
  };

  // Reference and pointer types: strip the reference/pointer and inspect the
  // referred-to type (e.g. function pointers, references to lambdas).
  template <typename T>
  struct function_traits<T&> : public function_traits<T> {};
  // Rvalue references (e.g. the decltype of a forwarded callable), mirroring
  // invoke_traits<T&&>.
  template <typename T>
  struct function_traits<T&&> : public function_traits<T> {};
  template <typename T>
  struct function_traits<T*> : public function_traits<T> {};

  // Free functions
  template <typename ReturnType, typename... Args>
  struct function_traits<ReturnType(Args...)> {
    // arity is the number of arguments.
    enum { arity = sizeof...(Args) };

    using ArgsTuple = std::tuple<Args...>;
    using result_type = ReturnType;

    // The i-th argument is equivalent to the i-th tuple element of a tuple
    // composed of those arguments.
    template <std::size_t i>
    struct arg {
      using type = std::tuple_element_t<i, std::tuple<Args...>>;
    };
  };

  #if defined(__cpp_noexcept_function_type)
  // Since C++17 noexcept is part of the function type, so R(Args...) noexcept
  // does not match the free-function specialization above and would otherwise
  // fall into the operator() fallback and fail to compile. Forwarding here
  // covers noexcept free functions, noexcept function pointers/references,
  // non-const noexcept member functions (which reach this via T ClassType::*),
  // and noexcept lambdas/functors.
  template <typename ReturnType, typename... Args>
  struct function_traits<ReturnType(Args...) noexcept> : public function_traits<ReturnType(Args...)> {
  };
  // Const noexcept member functions (e.g. the operator() of a noexcept lambda).
  template <typename ClassType, typename ReturnType, typename... Args>
  struct function_traits<ReturnType(ClassType::*)(Args...) const noexcept> : public function_traits<ReturnType(Args...)> {
  };
  #endif
  51. template <typename T>
  52. struct nullary_function_traits {
  53. using traits = function_traits<T>;
  54. using result_type = typename traits::result_type;
  55. };
  56. template <typename T>
  57. struct unary_function_traits {
  58. using traits = function_traits<T>;
  59. using result_type = typename traits::result_type;
  60. using arg1_t = typename traits::template arg<0>::type;
  61. };
  62. template <typename T>
  63. struct binary_function_traits {
  64. using traits = function_traits<T>;
  65. using result_type = typename traits::result_type;
  66. using arg1_t = typename traits::template arg<0>::type;
  67. using arg2_t = typename traits::template arg<1>::type;
  68. };
  69. // Traits for calling with c10::guts::invoke, where member_functions have a first argument of ClassType
  70. template <typename T>
  71. struct invoke_traits : public function_traits<T>{
  72. };
  73. template <typename T>
  74. struct invoke_traits<T&> : public invoke_traits<T>{
  75. };
  76. template <typename T>
  77. struct invoke_traits<T&&> : public invoke_traits<T>{
  78. };
  79. template <typename ClassType, typename ReturnType, typename... Args>
  80. struct invoke_traits<ReturnType(ClassType::*)(Args...)> :
  81. public function_traits<ReturnType(ClassType&, Args...)> {
  82. };
  83. template <typename ClassType, typename ReturnType, typename... Args>
  84. struct invoke_traits<ReturnType(ClassType::*)(Args...) const> :
  85. public function_traits<ReturnType(const ClassType&, Args...)> {
  86. };
  87. #else
  88. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  89. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)