FunctionRef.h 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. //===- llvm/ADT/STLExtras.h - Useful STL related functions ------*- C++ -*-===//
  3. //
  4. // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
  5. // See https://llvm.org/LICENSE.txt for license information.
  6. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  7. //
  8. //===----------------------------------------------------------------------===//
  9. //
  10. // This file contains some templates that are useful if you are working with the
  11. // STL at all.
  12. //
  13. // No library is required when using these functions.
  14. //
  15. //===----------------------------------------------------------------------===//
  16. // c10: modified from llvm::function_ref
  17. // c10: added more SFINAE to enable use in overloaded functions
  18. #pragma once
#include <cstddef>
#include <cstdint>
#include <functional>
#include <type_traits>
#include <utility>
  22. namespace c10 {
  23. /// An efficient, type-erasing, non-owning reference to a callable. This is
  24. /// intended for use as the type of a function parameter that is not used
  25. /// after the function in question returns.
  26. ///
  27. /// This class does not own the callable, so it is not in general safe to store
  28. /// a function_ref.
  29. template <typename Fn>
  30. class function_ref;
  31. template <typename Ret, typename... Params>
  32. class function_ref<Ret(Params...)> {
  33. Ret (*callback)(intptr_t callable, Params... params) = nullptr;
  34. intptr_t callable{};
  35. template <typename Callable>
  36. static Ret callback_fn(intptr_t callable, Params... params) {
  37. return (*reinterpret_cast<Callable*>(callable))(
  38. std::forward<Params>(params)...);
  39. }
  40. public:
  41. function_ref() = default;
  42. function_ref(std::nullptr_t) {}
  43. template <typename Callable>
  44. function_ref(
  45. // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  46. Callable&& callable,
  47. std::enable_if_t<!std::is_same_v<
  48. std::remove_reference_t<Callable>,
  49. function_ref>>* /*unused*/
  50. = nullptr,
  51. std::enable_if_t<std::is_convertible_v<
  52. typename std::invoke_result_t<Callable, Params...>,
  53. Ret>>* /*unused*/
  54. = nullptr)
  55. : callback(callback_fn<std::remove_reference_t<Callable>>),
  56. callable(reinterpret_cast<intptr_t>(&callable)) {}
  57. Ret operator()(Params... params) const {
  58. return callback(callable, std::forward<Params>(params)...);
  59. }
  60. operator bool() const {
  61. return callback;
  62. }
  63. };
  64. } // namespace c10
  65. #else
  66. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  67. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)