dlpack.h 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646
  1. #if !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)
  2. /*!
  3. * Copyright (c) 2017 - by Contributors
  4. * \file dlpack.h
  5. * \brief The common header of DLPack.
  6. */
  7. #ifndef DLPACK_DLPACK_H_
  8. #define DLPACK_DLPACK_H_
  9. /**
  10. * \brief Compatibility with C++
  11. */
  12. #ifdef __cplusplus
  13. #define DLPACK_EXTERN_C extern "C"
  14. #else
  15. #define DLPACK_EXTERN_C
  16. #endif
  17. /*! \brief The current major version of dlpack */
  18. #define DLPACK_MAJOR_VERSION 1
  19. /*! \brief The current minor version of dlpack */
  20. #define DLPACK_MINOR_VERSION 3
  21. /*! \brief DLPACK_DLL prefix for windows */
  22. #ifdef _WIN32
  23. #ifdef DLPACK_EXPORTS
  24. #define DLPACK_DLL __declspec(dllexport)
  25. #else
  26. #define DLPACK_DLL __declspec(dllimport)
  27. #endif
  28. #else
  29. #define DLPACK_DLL
  30. #endif
  31. #include <stdint.h>
  32. #include <stddef.h>
  33. #ifdef __cplusplus
  34. extern "C" {
  35. #endif
  36. /*!
  37. * \brief The DLPack version.
  38. *
  39. * A change in major version indicates that we have changed the
  40. * data layout of the ABI - DLManagedTensorVersioned.
  41. *
  42. * A change in minor version indicates that we have added new
  43. * code, such as a new device type, but the ABI is kept the same.
  44. *
  45. * If an obtained DLPack tensor has a major version that disagrees
  46. * with the version number specified in this header file
  47. * (i.e. major != DLPACK_MAJOR_VERSION), the consumer must call the deleter
  48. * (and it is safe to do so). It is not safe to access any other fields
  49. * as the memory layout will have changed.
  50. *
  51. * In the case of a minor version mismatch, the tensor can be safely used as
  52. * long as the consumer knows how to interpret all fields. Minor version
  53. * updates indicate the addition of enumeration values.
  54. */
  55. typedef struct {
  56. /*! \brief DLPack major version; changes when the ABI data layout changes. */
  57. uint32_t major;
  58. /*! \brief DLPack minor version; changes when enumeration values are added (ABI unchanged). */
  59. uint32_t minor;
  60. } DLPackVersion;
  61. /*!
  62. * \brief The device type in DLDevice (backed by int32_t when compiled as C++).
  63. */
  64. #ifdef __cplusplus
  65. typedef enum : int32_t {
  66. #else
  67. typedef enum {
  68. #endif
  69. /*! \brief CPU device */
  70. kDLCPU = 1,
  71. /*! \brief CUDA GPU device */
  72. kDLCUDA = 2,
  73. /*!
  74. * \brief Pinned CUDA CPU memory allocated by cudaMallocHost
  75. */
  76. kDLCUDAHost = 3,
  77. /*! \brief OpenCL devices. */
  78. kDLOpenCL = 4,
  79. /*! \brief Vulkan buffer for next generation graphics. */
  80. kDLVulkan = 7,
  81. /*! \brief Metal for Apple GPU. */
  82. kDLMetal = 8,
  83. /*! \brief Verilog simulator buffer */
  84. kDLVPI = 9,
  85. /*! \brief ROCm GPUs for AMD GPUs */
  86. kDLROCM = 10,
  87. /*!
  88. * \brief Pinned ROCm CPU memory allocated by hipMallocHost
  89. */
  90. kDLROCMHost = 11,
  91. /*!
  92. * \brief Reserved extension device type,
  93. * used to quickly test an extension device.
  94. * The semantics can differ depending on the implementation.
  95. */
  96. kDLExtDev = 12,
  97. /*!
  98. * \brief CUDA managed/unified memory allocated by cudaMallocManaged
  99. */
  100. kDLCUDAManaged = 13,
  101. /*!
  102. * \brief Unified shared memory allocated on a oneAPI non-partitioned
  103. * device. A call to the oneAPI runtime is required to determine the device
  104. * type, the USM allocation type and the sycl context it is bound to.
  105. *
  106. */
  107. kDLOneAPI = 14,
  108. /*! \brief GPU support for next generation WebGPU standard. */
  109. kDLWebGPU = 15,
  110. /*! \brief Qualcomm Hexagon DSP */
  111. kDLHexagon = 16,
  112. /*! \brief Microsoft MAIA devices */
  113. kDLMAIA = 17,
  114. /*! \brief AWS Trainium */
  115. kDLTrn = 18,
  116. } DLDeviceType;
  117. /*!
  118. * \brief A device (device type plus device index) for Tensor and operator.
  119. */
  120. typedef struct {
  121. /*! \brief The device type used in the device. */
  122. DLDeviceType device_type;
  123. /*!
  124. * \brief The device index.
  125. * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
  126. */
  127. int32_t device_id;
  128. } DLDevice;
  129. /*!
  130. * \brief The type code options for DLDataType::code.
  131. */
  132. typedef enum {
  133. /*! \brief signed integer */
  134. kDLInt = 0U,
  135. /*! \brief unsigned integer */
  136. kDLUInt = 1U,
  137. /*! \brief IEEE floating point */
  138. kDLFloat = 2U,
  139. /*!
  140. * \brief Opaque handle type, reserved for testing purposes.
  141. * Frameworks need to agree on the handle data type for the exchange to be well-defined.
  142. */
  143. kDLOpaqueHandle = 3U,
  144. /*! \brief bfloat16 */
  145. kDLBfloat = 4U,
  146. /*!
  147. * \brief complex number
  148. * (C/C++/Python layout: compact struct per complex number)
  149. */
  150. kDLComplex = 5U,
  151. /*! \brief boolean */
  152. kDLBool = 6U,
  153. /*! \brief FP8 data types */
  154. kDLFloat8_e3m4 = 7U,
  155. kDLFloat8_e4m3 = 8U,
  156. kDLFloat8_e4m3b11fnuz = 9U,
  157. kDLFloat8_e4m3fn = 10U,
  158. kDLFloat8_e4m3fnuz = 11U,
  159. kDLFloat8_e5m2 = 12U,
  160. kDLFloat8_e5m2fnuz = 13U,
  161. kDLFloat8_e8m0fnu = 14U,
  162. /*! \brief FP6 data types
  163. * Setting bits != 6 is currently unspecified; the producer must ensure bits is set to 6,
  164. * while the consumer must stop importing if the value is unexpected.
  165. */
  166. kDLFloat6_e2m3fn = 15U,
  167. kDLFloat6_e3m2fn = 16U,
  168. /*! \brief FP4 data types
  169. * Setting bits != 4 is currently unspecified; the producer must ensure bits is set to 4,
  170. * while the consumer must stop importing if the value is unexpected.
  171. */
  172. kDLFloat4_e2m1fn = 17U,
  173. } DLDataTypeCode;
  174. /*!
  175. * \brief The data type the tensor can hold. The data type is assumed to follow the
  176. * native endian-ness. An explicit error message should be raised when attempting to
  177. * export an array with non-native endianness.
  178. *
  179. * Examples
  180. * - float: type_code = 2, bits = 32, lanes = 1
  181. * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4
  182. * - int8: type_code = 0, bits = 8, lanes = 1
  183. * - std::complex<float>: type_code = 5, bits = 64, lanes = 1
  184. * - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits)
  185. * - float8_e4m3: type_code = 8, bits = 8, lanes = 1 (packed in memory)
  186. * - float6_e3m2fn: type_code = 16, bits = 6, lanes = 1 (packed in memory)
  187. * - float4_e2m1fn: type_code = 17, bits = 4, lanes = 1 (packed in memory)
  188. *
  189. * When a sub-byte type is packed, DLPack requires the data to be in little bit-endian, i.e.,
  190. * for a packed data set D, ((D >> (i * bits)) & bit_mask) stores the i-th element.
  191. */
  192. typedef struct {
  193. /*!
  194. * \brief Type code of base types.
  195. * We keep it uint8_t instead of DLDataTypeCode for minimal memory
  196. * footprint, but the value should be one of DLDataTypeCode enum values.
  197. * */
  198. uint8_t code;
  199. /*!
  200. * \brief Number of bits, common choices are 8, 16, 32 (sub-byte types use e.g. 4 or 6).
  201. */
  202. uint8_t bits;
  203. /*! \brief Number of lanes in the type, used for vector types. */
  204. uint16_t lanes;
  205. } DLDataType;
  206. /*!
  207. * \brief Plain C Tensor object, does not manage memory.
  208. */
  209. typedef struct {
  210. /*!
  211. * \brief The data pointer points to the allocated data. This will be CUDA
  212. * device pointer or cl_mem handle in OpenCL. It may be opaque on some device
  213. * types. This pointer is always aligned to 256 bytes as in CUDA. The
  214. * `byte_offset` field should be used to point to the beginning of the data.
  215. *
  216. * Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow,
  217. * TVM, perhaps others) do not adhere to this 256 byte alignment requirement
  218. * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed
  219. * (after which this note will be updated); at the moment it is recommended
  220. * to not rely on the data pointer being correctly aligned.
  221. *
  222. * For a given DLTensor, the size of memory required to store the contents of
  223. * data is calculated as follows:
  224. *
  225. * \code{.c}
  226. * static inline size_t GetDataSize(const DLTensor* t) {
  227. * size_t size = 1;
  228. * for (int32_t i = 0; i < t->ndim; ++i) {
  229. * size *= t->shape[i];
  230. * }
  231. * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
  232. * return size;
  233. * }
  234. * \endcode
  235. *
  236. * Note that if the tensor is of size zero, then the data pointer should be
  237. * set to `NULL`.
  238. */
  239. void* data;
  240. /*! \brief The device of the tensor */
  241. DLDevice device;
  242. /*! \brief Number of dimensions */
  243. int32_t ndim;
  244. /*! \brief The data type of the pointer */
  245. DLDataType dtype;
  246. /*!
  247. * \brief The shape of the tensor
  248. *
  249. * When ndim == 0, shape can be set to NULL.
  250. */
  251. int64_t* shape;
  252. /*!
  253. * \brief strides of the tensor (in number of elements, not bytes),
  254. * cannot be NULL if ndim != 0, must point to
  255. * an array of ndim elements that specifies the strides,
  256. * so the consumer can always rely on strides[dim] being valid for 0 <= dim < ndim.
  257. *
  258. * When ndim == 0, strides can be set to NULL.
  259. *
  260. * \note Before DLPack v1.2, strides can be NULL to indicate contiguous data.
  261. * This is not allowed in DLPack v1.2 and later. The rationale
  262. * is to simplify the consumer handling.
  263. */
  264. int64_t* strides;
  265. /*! \brief The offset in bytes to the beginning pointer to data */
  266. uint64_t byte_offset;
  267. } DLTensor;
  268. /*!
  269. * \brief C Tensor object, manage memory of DLTensor. This data structure is
  270. * intended to facilitate the borrowing of DLTensor by another framework. It is
  271. * not meant to transfer the tensor. When the borrowing framework doesn't need
  272. * the tensor, it should call the deleter to notify the host that the resource
  273. * is no longer needed.
  274. *
  275. * \note This data structure is used as Legacy DLManagedTensor
  276. * in DLPack exchange and is deprecated after DLPack v0.8.
  277. * Use DLManagedTensorVersioned instead.
  278. * This data structure may get renamed or deleted in future versions.
  279. *
  280. * \sa DLManagedTensorVersioned
  281. */
  282. typedef struct DLManagedTensor {
  283. /*! \brief DLTensor which is being memory managed */
  284. DLTensor dl_tensor;
  285. /*! \brief The context of the original host framework in which the
  286. * DLManagedTensor is used. It can also be NULL.
  287. */
  288. void * manager_ctx;
  289. /*!
  290. * \brief Destructor - this should be called
  291. * to destruct the manager_ctx which backs the DLManagedTensor. It can be
  292. * NULL if there is no way for the caller to provide a reasonable destructor.
  293. * The destructor deletes the argument self as well.
  294. */
  295. void (*deleter)(struct DLManagedTensor * self);
  296. } DLManagedTensor;
  297. // bit masks used in DLManagedTensorVersioned::flags (a uint64_t). NOTE(review): the masks are `unsigned long`, which is 32-bit on LLP64 targets such as 64-bit Windows; cast to uint64_t before applying `~` -- confirm.
  298. /*! \brief bit mask to indicate that the tensor is read only. */
  299. #define DLPACK_FLAG_BITMASK_READ_ONLY (1UL << 0UL)
  300. /*!
  301. * \brief bit mask to indicate that the tensor is a copy made by the producer.
  302. *
  303. * If set, the tensor is considered solely owned throughout its lifetime by the
  304. * consumer, until the producer-provided deleter is invoked.
  305. */
  306. #define DLPACK_FLAG_BITMASK_IS_COPIED (1UL << 1UL)
  307. /*!
  308. * \brief bit mask to indicate whether a sub-byte type is packed or padded.
  309. *
  310. * The default for sub-byte types (ex: fp4/fp6) is assumed packed. This flag can
  311. * be set by the producer to signal that a tensor of sub-byte type is padded.
  312. */
  313. #define DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED (1UL << 2UL)
  314. /*!
  315. * \brief A versioned and managed C Tensor object, manage memory of DLTensor.
  316. *
  317. * This data structure is intended to facilitate the borrowing of DLTensor by
  318. * another framework. It is not meant to transfer the tensor. When the borrowing
  319. * framework doesn't need the tensor, it should call the deleter to notify the
  320. * host that the resource is no longer needed.
  321. *
  322. * \note This is the current standard DLPack exchange data structure.
  323. */
  324. typedef struct DLManagedTensorVersioned {
  325. /*!
  326. * \brief The API and ABI version of the current managed Tensor
  327. */
  328. DLPackVersion version;
  329. /*!
  330. * \brief the context of the original host framework.
  331. *
  332. * The context in which the DLManagedTensorVersioned is used in the
  333. * framework. It can also be NULL.
  334. */
  335. void *manager_ctx;
  336. /*!
  337. * \brief Destructor.
  338. *
  339. * This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned.
  340. * It can be NULL if there is no way for the caller to provide a reasonable
  341. * destructor. The destructor deletes the argument self as well.
  342. */
  343. void (*deleter)(struct DLManagedTensorVersioned *self);
  344. /*!
  345. * \brief Additional bitmask flags information about the tensor.
  346. *
  347. * By default the flags should be set to 0.
  348. *
  349. * \note Future ABI changes should keep everything until this field
  350. * stable, to ensure that deleter can be correctly called.
  351. *
  352. * \sa DLPACK_FLAG_BITMASK_READ_ONLY
  353. * \sa DLPACK_FLAG_BITMASK_IS_COPIED, DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED
  354. */
  355. uint64_t flags;
  356. /*! \brief DLTensor which is being memory managed */
  357. DLTensor dl_tensor;
  358. } DLManagedTensorVersioned;
  359. //----------------------------------------------------------------------
  360. // DLPack `__dlpack_c_exchange_api__` fast exchange protocol definitions
  361. //----------------------------------------------------------------------
  362. /*!
  363. * \brief Request a producer library to create a new tensor.
  364. *
  365. * Create a new `DLManagedTensorVersioned` within the context of the producer
  366. * library. The allocation is defined via the prototype DLTensor.
  367. *
  368. * This function is exposed by the framework through the DLPackExchangeAPI.
  369. *
  370. * \param prototype The prototype DLTensor. Only the dtype, ndim, shape,
  371. * and device fields are used.
  372. * \param out The output DLManagedTensorVersioned.
  373. * \param error_ctx Context for `SetError`.
  374. * \param SetError The function to set the error.
  375. * \return An int status; the owning DLManagedTensorVersioned* is written to `out`.
  376. * NOTE(review): presumably 0 on success and -1 on failure, as for the other exchange
  377. * functions -- confirm. SetError is called exactly on failure (the implementer must ensure this).
  378. * \note - As a C function, must not throw C++ exceptions.
  379. * - Error propagation via SetError to avoid any direct need
  380. * of Python API. Due to this `SetError` may have to ensure the GIL is
  381. * held since it will presumably set a Python error.
  382. *
  383. * \sa DLPackExchangeAPI
  384. */
  385. typedef int (*DLPackManagedTensorAllocator)( //
  386. DLTensor* prototype, DLManagedTensorVersioned** out, void* error_ctx, //
  387. void (*SetError)(void* error_ctx, const char* kind, const char* message) //
  388. );
  389. /*!
  390. * \brief Exports a PyObject* Tensor/NDArray to a DLManagedTensorVersioned.
  391. *
  392. * This function does not perform any stream synchronization. The consumer should query
  393. * DLPackCurrentWorkStream to get the current work stream and launch kernels on it.
  394. *
  395. * This function is exposed by the framework through the DLPackExchangeAPI.
  396. *
  397. * \param py_object The Python object to convert. Must have the same type
  398. * as the one the `DLPackExchangeAPI` was discovered from.
  399. * \return An int status; the owning DLManagedTensorVersioned* is written to `out`.
  400. * NOTE(review): presumably 0 on success, -1 on failure with a Python exception set -- confirm.
  401. * If the data cannot be described using DLPack this should be a BufferError if possible.
  402. * \note - As a C function, must not throw C++ exceptions.
  403. *
  404. * \sa DLPackExchangeAPI, DLPackCurrentWorkStream
  405. */
  406. typedef int (*DLPackManagedTensorFromPyObjectNoSync)( //
  407. void* py_object, //
  408. DLManagedTensorVersioned** out //
  409. );
  410. /*!
  411. * \brief Exports a PyObject* Tensor/NDArray to a provided DLTensor.
  412. *
  413. * This function provides a faster interface for temporary, non-owning,
  414. * exchange. The producer (implementer) still owns the memory of data, strides,
  415. * shape. The liveness of the DLTensor and the data it views is only guaranteed
  416. * until control is returned.
  417. *
  418. * This function currently assumes that the producer (implementer) can fill
  419. * in the DLTensor shape and strides without the need for temporary allocations.
  420. *
  421. * This function does not perform any stream synchronization. The consumer
  422. * should query DLPackCurrentWorkStream to get the current work stream and
  423. * launch kernels on it.
  424. *
  425. * This function is exposed by the framework through the DLPackExchangeAPI.
  426. *
  427. * \param py_object The Python object to convert. Must have the same type
  428. * as the one the `DLPackExchangeAPI` was discovered from.
  429. * \param out The output DLTensor, whose space is pre-allocated on the stack.
  430. * \return 0 on success, -1 on failure with a Python exception set.
  431. * \note - As a C function, must not throw C++ exceptions.
  432. *
  433. * \sa DLPackExchangeAPI, DLPackCurrentWorkStream
  434. */
  435. typedef int (*DLPackDLTensorFromPyObjectNoSync)( //
  436. void* py_object, //
  437. DLTensor* out //
  438. );
  439. /*!
  440. * \brief Obtain the current work stream of a device.
  441. *
  442. * Obtain the current work stream of a device from the producer framework.
  443. * For example, it should map to torch.cuda.current_stream in PyTorch.
  444. *
  445. * When device_type is kDLCPU, the consumer does not have to query the stream
  446. * and the producer can simply return NULL when queried.
  447. * The consumer does not have to do anything on stream sync or setting.
  448. * So a CPU-only framework can just provide a dummy implementation that
  449. * always sets out_current_stream[0] to NULL.
  450. *
  451. * \param device_type The device type.
  452. * \param device_id The device id.
  453. * \param out_current_stream The output current work stream.
  454. *
  455. * \return 0 on success, -1 on failure with a Python exception set.
  456. * \note - As a C function, must not throw C++ exceptions.
  457. *
  458. * \sa DLPackExchangeAPI
  459. */
  460. typedef int (*DLPackCurrentWorkStream)( //
  461. DLDeviceType device_type, //
  462. int32_t device_id, //
  463. void** out_current_stream //
  464. );
  465. /*!
  466. * \brief Imports a DLManagedTensorVersioned to a PyObject* Tensor/NDArray.
  467. *
  468. * Convert an owning DLManagedTensorVersioned* to the Python tensor of the
  469. * producer (implementer) library with the correct type.
  470. *
  471. * This function does not perform any stream synchronization.
  472. *
  473. * This function is exposed by the framework through the DLPackExchangeAPI.
  474. *
  475. * \param tensor The DLManagedTensorVersioned to convert; the ownership of the
  476. * tensor is stolen.
  477. * \param out_py_object The output Python object.
  478. * \return 0 on success, -1 on failure with a Python exception set.
  479. *
  480. * \sa DLPackExchangeAPI
  481. */
  482. typedef int (*DLPackManagedTensorToPyObjectNoSync)( //
  483. DLManagedTensorVersioned* tensor, //
  484. void** out_py_object //
  485. );
  486. /*!
  487. * \brief DLPackExchangeAPI stable header.
  488. * \sa DLPackExchangeAPI
  489. */
  490. typedef struct DLPackExchangeAPIHeader {
  491. /*!
  492. * \brief The provided DLPack version. The consumer must check major version
  493. * compatibility before using this struct.
  494. */
  495. DLPackVersion version;
  496. /*!
  497. * \brief Optional pointer to an older DLPackExchangeAPI in the chain.
  498. *
  499. * It must be NULL if the framework does not support older versions.
  500. * If the current major version is larger than the one supported by the
  501. * consumer, the consumer may walk this to find an earlier supported version.
  502. *
  503. * \sa DLPackExchangeAPI
  504. */
  505. struct DLPackExchangeAPIHeader* prev_api;
  506. } DLPackExchangeAPIHeader;
  507. /*!
  508. * \brief Framework-specific function pointers table for DLPack exchange.
  509. *
  510. * In addition to `__dlpack__()`, we define a C function table sharable by
  511. * Python implementations via `__dlpack_c_exchange_api__`.
  512. *
  513. * This attribute must be set on the type as a Python PyCapsule
  514. * with name "dlpack_exchange_api".
  515. *
  516. * A consumer library may use a pattern such as:
  517. *
  518. * \code
  519. *
  520. * PyObject *api_obj = type(tensor_obj).__dlpack_c_exchange_api__; // as C-code
  521. * MyDLPackExchangeAPI *api = PyCapsule_GetPointer(api_obj, "dlpack_exchange_api");
  522. * if (api == NULL && PyErr_Occurred()) { goto handle_error; }
  523. *
  524. * \endcode
  525. *
  526. * Note that this must be defined on the type. The consumer should look up the
  527. * attribute on the type and may cache the result for each unique type.
  528. *
  529. * The precise API table is given by:
  530. * \code
  531. * struct MyDLPackExchangeAPI : public DLPackExchangeAPI {
  532. * MyDLPackExchangeAPI() {
  533. * header.version.major = DLPACK_MAJOR_VERSION;
  534. * header.version.minor = DLPACK_MINOR_VERSION;
  535. * header.prev_api = nullptr;
  536. *
  537. * managed_tensor_allocator = MyDLPackManagedTensorAllocator;
  538. * managed_tensor_from_py_object_no_sync = MyDLPackManagedTensorFromPyObjectNoSync;
  539. * managed_tensor_to_py_object_no_sync = MyDLPackManagedTensorToPyObjectNoSync;
  540. * dltensor_from_py_object_no_sync = MyDLPackDLTensorFromPyObjectNoSync;
  541. * current_work_stream = MyDLPackCurrentWorkStream;
  542. * }
  543. *
  544. * static const DLPackExchangeAPI* Global() {
  545. * static MyDLPackExchangeAPI inst;
  546. * return &inst;
  547. * }
  548. * };
  549. * \endcode
  550. *
  551. * Guidelines for leveraging DLPackExchangeAPI:
  552. *
  553. * There are generally two kinds of consumer needs for DLPack exchange:
  554. * - N0: library support, where consumer.kernel(x, y, z) would like to run a kernel
  555. * with the data from x, y, z. The consumer is also expected to run the kernel with the same
  556. * stream context as the producer. For example, when x, y, z is torch.Tensor,
  557. * consumer should query exchange_api->current_work_stream to get the
  558. * current stream and launch the kernel with the same stream.
  559. * This setup is necessary for no synchronization in kernel launch and maximum compatibility
  560. * with CUDA graph capture in the producer.
  561. * This is the desirable behavior for library extension support for frameworks like PyTorch.
  562. * - N1: data ingestion and retention
  563. *
  564. * Note that obj.__dlpack__() API should provide useful ways for N1.
  565. * The primary focus of the current DLPackExchangeAPI is to enable faster exchange N0
  566. * with the support of the function pointer current_work_stream.
  567. *
  568. * Array/Tensor libraries should statically create and initialize this structure,
  569. * then expose a pointer to DLPackExchangeAPI through the PyCapsule described above.
  570. * The DLPackExchangeAPI* must stay alive throughout the lifetime of the process.
  571. *
  572. * One simple way to do so is to create a static instance of DLPackExchangeAPI
  573. * within the framework and return a pointer to it. The code example above
  574. * shows how to do so in C++. It should also be reasonably easy
  575. * to do so in other languages.
  576. */
  577. typedef struct DLPackExchangeAPI {
  578. /*!
  579. * \brief The header that remains stable across versions.
  580. */
  581. DLPackExchangeAPIHeader header;
  582. /*!
  583. * \brief Producer function pointer for DLPackManagedTensorAllocator
  584. * This function must not be NULL.
  585. * \sa DLPackManagedTensorAllocator
  586. */
  587. DLPackManagedTensorAllocator managed_tensor_allocator;
  588. /*!
  589. * \brief Producer function pointer for DLPackManagedTensorFromPyObjectNoSync
  590. * This function must not be NULL.
  591. * \sa DLPackManagedTensorFromPyObjectNoSync
  592. */
  593. DLPackManagedTensorFromPyObjectNoSync managed_tensor_from_py_object_no_sync;
  594. /*!
  595. * \brief Producer function pointer for DLPackManagedTensorToPyObjectNoSync
  596. * This function must not be NULL.
  597. * \sa DLPackManagedTensorToPyObjectNoSync
  598. */
  599. DLPackManagedTensorToPyObjectNoSync managed_tensor_to_py_object_no_sync;
  600. /*!
  601. * \brief Producer function pointer for DLPackDLTensorFromPyObjectNoSync
  602. * This function can be NULL when the producer does not support this function.
  603. * \sa DLPackDLTensorFromPyObjectNoSync
  604. */
  605. DLPackDLTensorFromPyObjectNoSync dltensor_from_py_object_no_sync;
  606. /*!
  607. * \brief Producer function pointer for DLPackCurrentWorkStream
  608. * This function must not be NULL.
  609. * \sa DLPackCurrentWorkStream
  610. */
  611. DLPackCurrentWorkStream current_work_stream;
  612. } DLPackExchangeAPI;
  613. #ifdef __cplusplus
  614. } // DLPACK_EXTERN_C
  615. #endif
  616. #endif // DLPACK_DLPACK_H_
  617. #else
  618. #error "This file should not be included when either TORCH_STABLE_ONLY or TORCH_TARGET_VERSION is defined."
  619. #endif // !defined(TORCH_STABLE_ONLY) && !defined(TORCH_TARGET_VERSION)