| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
- #pragma once
- #include <ATen/ATen.h>
- #include <ATen/Tensor.h>
- #include <ATen/dlpack.h>
- // this converter will:
- // 1) take a Tensor object and wrap it in the DLPack tensor
- // 2) take a dlpack tensor and convert it to the ATen Tensor
- namespace at {
- TORCH_API ScalarType toScalarType(const DLDataType& dtype);
- TORCH_API DLManagedTensor* toDLPack(const Tensor& src);
- TORCH_API struct DLManagedTensorVersioned* toDLPackVersioned(const Tensor& src);
- TORCH_API void toDLPackNonOwning(const Tensor& src, DLTensor* out);
- TORCH_API Tensor
- fromDLPack(DLManagedTensor* src, std::function<void(void*)> deleter = {});
- TORCH_API Tensor fromDLPackVersioned(
- DLManagedTensorVersioned* src,
- std::function<void(void*)> deleter = {});
- TORCH_API DLDataType getDLDataType(const Tensor& t);
- TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id);
- // Copies the Tensor if there's a device mismatch or copy is forced.
- // This should be used before actually creating the DLPack capsule.
- TORCH_API Tensor maybeCopyTensor(
- const Tensor& data,
- std::optional<DLDevice> optional_dl_device,
- std::optional<bool> copy);
- // Converts the given at::Device into a DLDevice.
- TORCH_API DLDevice torchDeviceToDLDevice(at::Device device);
- // Converts the DLDevice to an ATen device.
- TORCH_API Device dlDeviceToTorchDevice(
- DLDeviceType type,
- c10::DeviceIndex index,
- void* data = nullptr);
- // This trait class is used for retrieving different attributes, such as the
- // PyCapsule names and conversion functions for both DLPack tensor classes:
- // `DLManagedTensor` and `DLManagedTensorVersioned`.
- //
- // Each specialization should contain the following 2 traits:
- // - `capsule`: actual name of the capsule
- // - `used`: name of the capsule after using it
- // - `toDLPack`: function for converting a tensor into a DLPack capsule
- // - `fromDLPack`: function for creating a tensor from a DLPack capsule
- //
- // While `toDLPack` is the directly exposed to Python, `fromDLPack` is not.
- // Although it contains the core implementation, it lacks the required book
- // keeping logic contained in its caller `tensor_fromDLPack`.
- //
- // That said, `fromDLPack` is used directly in a few DLPack tests that live
- // inside ATen (no Python available).
- template <class T>
- struct DLPackTraits {};
- template <>
- struct DLPackTraits<DLManagedTensor> {
- inline static constexpr const char* capsule = "dltensor";
- inline static constexpr const char* used = "used_dltensor";
- inline static auto toDLPack = at::toDLPack;
- inline static auto fromDLPack = at::fromDLPack;
- };
- template <>
- struct DLPackTraits<DLManagedTensorVersioned> {
- inline static constexpr const char* capsule = "dltensor_versioned";
- inline static constexpr const char* used = "used_dltensor_versioned";
- inline static auto toDLPack = at::toDLPackVersioned;
- inline static auto fromDLPack = at::fromDLPackVersioned;
- };
- } // namespace at
- #else
- #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
- #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
|