PrivateUse1HooksInterface.h 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/core/GeneratorForPrivateuseone.h>
  4. #include <ATen/detail/AcceleratorHooksInterface.h>
  5. #include <c10/core/Allocator.h>
  6. #include <c10/core/Device.h>
  7. #include <c10/core/Storage.h>
  8. #include <c10/util/Exception.h>
  9. C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter")
  10. namespace at {
  11. struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface {
  12. #define FAIL_PRIVATEUSE1HOOKS_FUNC(func) \
  13. TORCH_CHECK_NOT_IMPLEMENTED( \
  14. false, \
  15. "You should register `PrivateUse1HooksInterface`", \
  16. "by `RegisterPrivateUse1HooksInterface` and implement `", \
  17. func, \
  18. "` at the same time for PrivateUse1.");
  19. ~PrivateUse1HooksInterface() override = default;
  20. bool isBuilt() const override {
  21. FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
  22. }
  23. bool isAvailable() const override {
  24. FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
  25. }
  26. const at::Generator& getDefaultGenerator(
  27. c10::DeviceIndex device_index) const override {
  28. FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
  29. }
  30. Generator getNewGenerator(
  31. [[maybe_unused]] DeviceIndex device_index = -1) const override {
  32. // TODO(FFFrog): Preserved for BC and will be removed in the future.
  33. if (at::GetGeneratorPrivate().has_value())
  34. return at::GetGeneratorForPrivateuse1(device_index);
  35. FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
  36. }
  37. at::Device getDeviceFromPtr(void* data) const override {
  38. FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
  39. }
  40. bool isPinnedPtr(const void* data) const override {
  41. return false;
  42. }
  43. Allocator* getPinnedMemoryAllocator() const override {
  44. FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
  45. }
  46. bool hasPrimaryContext(DeviceIndex device_index) const override {
  47. FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
  48. }
  49. void init() const override {}
  50. virtual void resizePrivateUse1Bytes(
  51. const c10::Storage& storage,
  52. size_t newsize) const {
  53. FAIL_PRIVATEUSE1HOOKS_FUNC(__func__);
  54. }
  55. #undef FAIL_PRIVATEUSE1HOOKS_FUNC
  56. };
  57. struct TORCH_API PrivateUse1HooksArgs {};
  58. TORCH_API void RegisterPrivateUse1HooksInterface(
  59. at::PrivateUse1HooksInterface* hook_);
  60. TORCH_API bool isPrivateUse1HooksRegistered();
  61. namespace detail {
  62. TORCH_API const at::PrivateUse1HooksInterface& getPrivateUse1Hooks();
  63. } // namespace detail
  64. } // namespace at
  65. C10_DIAGNOSTIC_POP()
  66. #else
  67. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  68. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)