DLConvertor.h 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. #pragma once
  3. #include <ATen/ATen.h>
  4. #include <ATen/Tensor.h>
  5. #include <ATen/dlpack.h>
// This converter will:
// 1) take a Tensor object and wrap it in a DLPack tensor;
// 2) take a DLPack tensor and convert it into an ATen Tensor.
  9. namespace at {
  10. TORCH_API ScalarType toScalarType(const DLDataType& dtype);
  11. TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
  12. TORCH_API struct DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src);
  13. TORCH_API void toDLPackNonOwning(const Tensor& src, DLTensor* out);
  14. TORCH_API Tensor
  15. fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter = {});
  16. TORCH_API Tensor fromDLPackVersioned(
  17. DLManagedTensorVersioned* src,
  18. std::function<void(void*)> deleter = {});
  19. TORCH_API DLDataType getDLDataType(const Tensor& t);
  20. TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
  21. // Copies the Tensor if there's a device mismatch or copy is forced.
  22. // This should be used before actually creating the DLPack capsule.
  23. TORCH_API Tensor maybeCopyTensor(
  24. const Tensor& data,
  25. std::optional<DLDevice> optional_dl_device,
  26. std::optional<bool> copy);
  27. // Converts the given at::Device into a DLDevice.
  28. TORCH_API DLDevice torchDeviceToDLDevice(at::Device device);
  29. // Converts the DLDevice to an ATen device.
  30. TORCH_API Device dlDeviceToTorchDevice(
  31. DLDeviceType type,
  32. c10::DeviceIndex index,
  33. void* data = nullptr);
  34. // This trait class is used for retrieving different attributes, such as the
  35. // PyCapsule names and conversion functions for both DLPack tensor classes:
  36. // `DLManagedTensor` and `DLManagedTensorVersioned`.
  37. //
  38. // Each specialization should contain the following 2 traits:
  39. // - `capsule`: actual name of the capsule
  40. // - `used`: name of the capsule after using it
  41. // - `toDLPack`: function for converting a tensor into a DLPack capsule
  42. // - `fromDLPack`: function for creating a tensor from a DLPack capsule
  43. //
  44. // While `toDLPack` is the directly exposed to Python, `fromDLPack` is not.
  45. // Although it contains the core implementation, it lacks the required book
  46. // keeping logic contained in its caller `tensor_fromDLPack`.
  47. //
  48. // That said, `fromDLPack` is used directly in a few DLPack tests that live
  49. // inside ATen (no Python available).
  50. template <class T>
  51. struct DLPackTraits {};
  52. template <>
  53. struct DLPackTraits<DLManagedTensor> {
  54. inline static constexpr const char* capsule = "dltensor";
  55. inline static constexpr const char* used = "used_dltensor";
  56. inline static auto toDLPack = at::toDLPack;
  57. inline static auto fromDLPack = at::fromDLPack;
  58. };
  59. template <>
  60. struct DLPackTraits<DLManagedTensorVersioned> {
  61. inline static constexpr const char* capsule = "dltensor_versioned";
  62. inline static constexpr const char* used = "used_dltensor_versioned";
  63. inline static auto toDLPack = at::toDLPackVersioned;
  64. inline static auto fromDLPack = at::fromDLPackVersioned;
  65. };
  66. } // namespace at
  67. #else
  68. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  69. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)