driver.c 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. #define __HIP_PLATFORM_AMD__
  2. #include <hip/hip_runtime.h>
  3. #include <hip/hip_runtime_api.h>
  4. #define PY_SSIZE_T_CLEAN
  5. #include <Python.h>
  6. #include <stdbool.h>
  7. #include <stdio.h>
  8. #include <stdlib.h>
  9. #ifdef _WIN32
  10. #include <windows.h>
  11. // Windows compatibility layer for dlopen/dlsym/dlclose/dlerror
  12. #define RTLD_NOW 0
  13. #define RTLD_LAZY 0
  14. #define RTLD_LOCAL 0
  15. static char dlerror_buf[512];
  16. static inline void *dlopen(const char *filename, int flags) {
  17. (void)flags;
  18. HMODULE h = LoadLibraryA(filename);
  19. if (!h) {
  20. snprintf(dlerror_buf, sizeof(dlerror_buf),
  21. "LoadLibrary failed with error %lu", GetLastError());
  22. }
  23. return (void *)h;
  24. }
  25. static inline void *dlsym(void *handle, const char *symbol) {
  26. void *p = (void *)GetProcAddress((HMODULE)handle, symbol);
  27. if (!p) {
  28. snprintf(dlerror_buf, sizeof(dlerror_buf),
  29. "GetProcAddress failed for %s with error %lu", symbol,
  30. GetLastError());
  31. }
  32. return p;
  33. }
  34. static inline int dlclose(void *handle) {
  35. return FreeLibrary((HMODULE)handle) ? 0 : -1;
  36. }
  37. static inline const char *dlerror(void) {
  38. return dlerror_buf[0] ? dlerror_buf : NULL;
  39. }
  40. #else
  41. #include <dlfcn.h>
  42. #endif
  43. typedef struct {
  44. uint32_t group0_0;
  45. uint32_t group0_1;
  46. uint32_t group0_2;
  47. uint32_t group0_3;
  48. uint32_t group1_0;
  49. uint32_t group1_1;
  50. uint32_t group1_2;
  51. uint32_t group1_3;
  52. uint32_t group1_4;
  53. uint32_t group1_5;
  54. uint32_t group1_6;
  55. uint32_t group1_7;
  56. } TDMDescriptor;
  57. typedef struct {
  58. PyObject_HEAD;
  59. TDMDescriptor desc;
  60. } PyTDMDescriptorObject;
  61. static PyObject *PyTDMDescriptor_new(PyTypeObject *type, PyObject *args,
  62. PyObject *kw) {
  63. PyTDMDescriptorObject *self =
  64. (PyTDMDescriptorObject *)type->tp_alloc(type, 0);
  65. if (!self)
  66. return NULL;
  67. memset(&self->desc, 0, sizeof(self->desc));
  68. return (PyObject *)self;
  69. }
  70. static void PyTDMDescriptor_dealloc(PyTDMDescriptorObject *self) {
  71. Py_TYPE(self)->tp_free((PyObject *)self);
  72. }
  73. static PyTypeObject PyTDMDescriptorType = {
  74. PyVarObject_HEAD_INIT(NULL, 0).tp_name =
  75. "triton.backends.amd.PyTDMDescriptor",
  76. .tp_basicsize = sizeof(PyTDMDescriptorObject),
  77. .tp_itemsize = 0,
  78. .tp_flags = Py_TPFLAGS_DEFAULT,
  79. .tp_doc = "PyObject for TDMDescriptor",
  80. .tp_new = PyTDMDescriptor_new,
  81. .tp_dealloc = (destructor)PyTDMDescriptor_dealloc,
  82. };
  83. // TODO: Both host-side and device-side TDM descriptor follow the same encoding
  84. // format. Consider to add a common utility to remove duplicate code.
  85. static bool encodeTDMDescriptor(TDMDescriptor *desc, int elementBitWidth,
  86. uint32_t *blockSize, int numWarps,
  87. int padInterval, int padAmount, uint32_t *shape,
  88. uint32_t *strides, uint64_t globalAddress,
  89. int rank) {
  90. // NYI: TDM > 2D cases
  91. if (rank != 2)
  92. return false;
  93. // Get warp distribution
  94. uint32_t numWarpsDim0 = numWarps;
  95. for (; numWarpsDim0 > blockSize[0]; numWarpsDim0 /= 2)
  96. ;
  97. uint32_t numWarpsDim1 = numWarps / numWarpsDim0;
  98. if (!(numWarpsDim0 > 0 && blockSize[1] % numWarpsDim1 == 0))
  99. return false;
  100. uint32_t blockSize0 = (blockSize[0] + numWarpsDim0 - 1) / numWarpsDim0;
  101. uint32_t blockSize1 = (blockSize[1] + numWarpsDim1 - 1) / numWarpsDim1;
  102. // group0 (128 bits / 4 dwords) effective bit encoding:
  103. // [120:64]: global address
  104. // [127:126]: type - currently always set to 0x2
  105. desc->group0_2 = (uint32_t)(globalAddress & 0xFFFFFFFF);
  106. desc->group0_3 = (uint32_t)((globalAddress >> 32) & 0x01FFFFFF);
  107. desc->group0_3 |= (1U << 31);
  108. // group1 (256 bits / 8 dwords) effective bit encoding:
  109. // [17:16]: data size - log2(element size in bytes)
  110. // [20]: enable padding
  111. // [24:22]: pad interval - log2(pad interval in dwords) - 1
  112. // [31:25]: pad amount - pad amount in dwords - 1
  113. // [79:48]: tensor shape dim inner
  114. // [111:80]: tensor shape dim outer
  115. // [127:112]: block shape dim inner
  116. // [143:128]: block shape dim outer
  117. // [207:160]: tensor stride dim outer (we only use 32 bits)
  118. int elementSizeInBytes = elementBitWidth / 8;
  119. int dataSize = log2(elementSizeInBytes);
  120. desc->group1_0 = (dataSize << 16);
  121. int dwordSize = 32;
  122. int padIntervalInDwords = padInterval * elementBitWidth / dwordSize;
  123. int padAmountInDwords = padAmount * elementBitWidth / dwordSize;
  124. if (padIntervalInDwords > 0 && padAmountInDwords > 0) {
  125. int log2PadInterval = log2(padIntervalInDwords);
  126. desc->group1_0 |= (1 << 20);
  127. desc->group1_0 |= ((log2PadInterval - 1) << 22);
  128. desc->group1_0 |= ((padAmountInDwords - 1) << 25);
  129. }
  130. desc->group1_1 = (shape[1] << 16);
  131. desc->group1_2 = (shape[1] >> 16);
  132. desc->group1_2 |= (shape[0] << 16);
  133. desc->group1_3 = (shape[0] >> 16);
  134. desc->group1_3 |= (blockSize1 << 16);
  135. desc->group1_4 = (blockSize0 & 0xFFFF);
  136. desc->group1_5 = strides[0];
  137. return true;
  138. }
  139. // The list of paths to search for the HIP runtime library. The caller Python
  140. // code should substitute the search path placeholder.
  141. static const char *hipLibSearchPaths[] = {"/*py_libhip_search_path*/"};
  142. // The list of HIP dynamic library symbols and their signature we are interested
  143. // in this file.
  144. // |FOR_EACH_ERR_FN| is a macro to process APIs that return hipError_t;
  145. // |FOR_EACH_STR_FN| is a macro to process APIs that return const char *.
  146. #define HIP_SYMBOL_LIST(FOR_EACH_ERR_FN, FOR_EACH_STR_FN) \
  147. FOR_EACH_STR_FN(hipGetErrorString, hipError_t hipError) \
  148. FOR_EACH_ERR_FN(hipGetDeviceProperties, hipDeviceProp_t *prop, int deviceId) \
  149. FOR_EACH_ERR_FN(hipModuleLoadDataEx, hipModule_t *module, const void *image, \
  150. unsigned int numOptions, hipJitOption *options, \
  151. void **optionValues) \
  152. FOR_EACH_ERR_FN(hipModuleGetFunction, hipFunction_t *function, \
  153. hipModule_t module, const char *kname) \
  154. FOR_EACH_ERR_FN(hipFuncGetAttribute, int *, hipFunction_attribute attr, \
  155. hipFunction_t function)
  156. // HIP driver version format: HIP_VERSION_MAJOR * 10000000 + HIP_VERSION_MINOR *
  157. // 100000 + HIP_VERSION_PATCH.
  158. #define TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version) ((version) / 10000000)
  159. #define TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version) \
  160. (((version) % 10000000) / 100000)
  161. #define TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version) ((version) % 100000)
  162. #define TRITON_HIP_DRIVER_REQ_MAJOR_VERSION (6)
  163. // #define TRITON_HIP_DRIVER_DBG_VERSION
  164. #ifdef TRITON_HIP_DRIVER_DBG_VERSION
  165. #define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
  166. do { \
  167. snprintf(msgBuff, sizeof(msgBuff), "libamdhip64 version is: %d.%d.%d", \
  168. TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(version), \
  169. TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(version), \
  170. TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(version)); \
  171. printf("%s\n", msgBuff); \
  172. } while (0);
  173. #else
  174. #define TRITON_HIP_DRIVER_LOG_VERSION(version, msgBuff) \
  175. do { \
  176. (void)msgBuff; \
  177. (void)(version); \
  178. } while (0);
  179. #endif
  180. #define TRITON_HIP_MSG_BUFF_SIZE (1024U)
  181. // The HIP symbol table for holding resolved dynamic library symbols.
  182. struct HIPSymbolTable {
  183. #define DEFINE_EACH_ERR_FIELD(hipSymbolName, ...) \
  184. hipError_t (*hipSymbolName)(__VA_ARGS__);
  185. #define DEFINE_EACH_STR_FIELD(hipSymbolName, ...) \
  186. const char *(*hipSymbolName)(__VA_ARGS__);
  187. HIP_SYMBOL_LIST(DEFINE_EACH_ERR_FIELD, DEFINE_EACH_STR_FIELD)
  188. };
  189. static struct HIPSymbolTable hipSymbolTable;
  190. static int checkDriverVersion(void *lib) {
  191. int hipVersion = -1;
  192. const char *error = NULL;
  193. typedef hipError_t (*hipDriverGetVersion_fn)(int *driverVersion);
  194. hipDriverGetVersion_fn hipDriverGetVersion;
  195. dlerror(); // Clear existing errors
  196. hipDriverGetVersion =
  197. (hipDriverGetVersion_fn)dlsym(lib, "hipDriverGetVersion");
  198. error = dlerror();
  199. if (error) {
  200. PyErr_SetString(PyExc_RuntimeError,
  201. "cannot query 'hipDriverGetVersion' from libamdhip64.so");
  202. dlclose(lib);
  203. return -1;
  204. }
  205. (void)hipDriverGetVersion(&hipVersion);
  206. char msgBuff[TRITON_HIP_MSG_BUFF_SIZE] = {0};
  207. const int hipMajVersion = TRITON_HIP_DRIVER_EXTRACT_MAJOR_VERSION(hipVersion);
  208. if (hipMajVersion < TRITON_HIP_DRIVER_REQ_MAJOR_VERSION) {
  209. const int hipMinVersion =
  210. TRITON_HIP_DRIVER_EXTRACT_MINOR_VERSION(hipVersion);
  211. const int hipPatchVersion =
  212. TRITON_HIP_DRIVER_EXTRACT_PATCH_VERSION(hipVersion);
  213. snprintf(msgBuff, sizeof(msgBuff),
  214. "libamdhip64 version %d.%d.%d is not supported! Required major "
  215. "version is >=%d.",
  216. hipMajVersion, hipMinVersion, hipPatchVersion,
  217. TRITON_HIP_DRIVER_REQ_MAJOR_VERSION);
  218. PyErr_SetString(PyExc_RuntimeError, msgBuff);
  219. dlclose(lib);
  220. return -1;
  221. }
  222. TRITON_HIP_DRIVER_LOG_VERSION(hipVersion, msgBuff);
  223. return hipVersion;
  224. }
  225. bool initSymbolTable() {
  226. void *lib;
  227. // Go through the list of search paths to dlopen the first HIP driver library.
  228. int n = sizeof(hipLibSearchPaths) / sizeof(hipLibSearchPaths[0]);
  229. for (int i = 0; i < n; ++i) {
  230. void *handle = dlopen(hipLibSearchPaths[i], RTLD_LAZY | RTLD_LOCAL);
  231. if (handle) {
  232. lib = handle;
  233. // printf("[triton] chosen %s\n", hipLibSearchPaths[i]);
  234. }
  235. }
  236. if (!lib) {
  237. PyErr_SetString(PyExc_RuntimeError, "cannot open libamdhip64.so");
  238. return false;
  239. }
  240. int hipVersion = checkDriverVersion(lib);
  241. if (hipVersion == -1)
  242. return false;
  243. const char *error = NULL;
  244. typedef hipError_t (*hipGetProcAddress_fn)(
  245. const char *symbol, void **pfn, int hipVersion, uint64_t hipFlags,
  246. hipDriverProcAddressQueryResult *symbolStatus);
  247. hipGetProcAddress_fn hipGetProcAddress;
  248. dlerror(); // Clear existing errors
  249. *(void **)&hipGetProcAddress = dlsym(lib, "hipGetProcAddress");
  250. error = dlerror();
  251. if (error) {
  252. PyErr_SetString(PyExc_RuntimeError,
  253. "cannot query 'hipGetProcAddress' from libamdhip64.so");
  254. dlclose(lib);
  255. return false;
  256. }
  257. // Resolve all symbols we are interested in.
  258. uint64_t hipFlags = 0;
  259. hipDriverProcAddressQueryResult symbolStatus;
  260. hipError_t status = hipSuccess;
  261. #define QUERY_EACH_FN(hipSymbolName, ...) \
  262. status = hipGetProcAddress(#hipSymbolName, \
  263. (void **)&hipSymbolTable.hipSymbolName, \
  264. hipVersion, hipFlags, &symbolStatus); \
  265. if (status != hipSuccess) { \
  266. PyErr_SetString(PyExc_RuntimeError, \
  267. "cannot get address for '" #hipSymbolName \
  268. "' from libamdhip64.so"); \
  269. dlclose(lib); \
  270. return false; \
  271. }
  272. HIP_SYMBOL_LIST(QUERY_EACH_FN, QUERY_EACH_FN)
  273. return true;
  274. }
  275. static inline void gpuAssert(hipError_t code, const char *file, int line) {
  276. {
  277. if (code != HIP_SUCCESS) {
  278. {
  279. const char *prefix = "Triton Error [HIP]: ";
  280. const char *str = hipSymbolTable.hipGetErrorString(code);
  281. char err[TRITON_HIP_MSG_BUFF_SIZE] = {0};
  282. snprintf(err, sizeof(err), "%s Code: %d, Messsage: %s", prefix, code,
  283. str);
  284. PyGILState_STATE gil_state;
  285. gil_state = PyGILState_Ensure();
  286. PyErr_SetString(PyExc_RuntimeError, err);
  287. PyGILState_Release(gil_state);
  288. }
  289. }
  290. }
  291. }
  292. #define HIP_CHECK(ans) \
  293. { \
  294. gpuAssert((ans), __FILE__, __LINE__); \
  295. if (PyErr_Occurred()) \
  296. return NULL; \
  297. }
  298. static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  299. int device_id;
  300. if (!PyArg_ParseTuple(args, "i", &device_id))
  301. return NULL;
  302. hipDeviceProp_t props;
  303. HIP_CHECK(hipSymbolTable.hipGetDeviceProperties(&props, device_id));
  304. // create a struct to hold device properties
  305. return Py_BuildValue(
  306. "{s:i, s:i, s:i, s:i, s:i, s:i, s:s, s:i, s:i}", "max_shared_mem",
  307. props.sharedMemPerBlock, "max_num_regs", props.regsPerBlock,
  308. "multiprocessor_count", props.multiProcessorCount, "sm_clock_rate",
  309. props.clockRate, "mem_clock_rate", props.memoryClockRate, "mem_bus_width",
  310. props.memoryBusWidth, "arch", props.gcnArchName, "warpSize",
  311. props.warpSize, "max_threads_per_sm", props.maxThreadsPerMultiProcessor);
  312. }
  313. static PyObject *loadBinary(PyObject *self, PyObject *args) {
  314. const char *name;
  315. const char *data;
  316. Py_ssize_t data_size;
  317. int shared;
  318. int device;
  319. if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
  320. &device)) {
  321. return NULL;
  322. }
  323. // set HIP options
  324. hipJitOption opt[] = {hipJitOptionErrorLogBufferSizeBytes,
  325. hipJitOptionErrorLogBuffer,
  326. hipJitOptionInfoLogBufferSizeBytes,
  327. hipJitOptionInfoLogBuffer, hipJitOptionLogVerbose};
  328. const unsigned int errbufsize = 8192;
  329. const unsigned int logbufsize = 8192;
  330. char _err[errbufsize];
  331. char _log[logbufsize];
  332. void *optval[] = {(void *)(uintptr_t)errbufsize, (void *)_err,
  333. (void *)(uintptr_t)logbufsize, (void *)_log, (void *)1};
  334. // launch HIP Binary
  335. hipModule_t mod;
  336. hipFunction_t fun;
  337. HIP_CHECK(hipSymbolTable.hipModuleLoadDataEx(&mod, data, 5, opt, optval))
  338. HIP_CHECK(hipSymbolTable.hipModuleGetFunction(&fun, mod, name));
  339. // get allocated registers and spilled registers from the function
  340. int n_regs = 0;
  341. int n_spills = 0;
  342. int32_t n_max_threads = 0;
  343. hipSymbolTable.hipFuncGetAttribute(&n_regs, HIP_FUNC_ATTRIBUTE_NUM_REGS, fun);
  344. hipSymbolTable.hipFuncGetAttribute(&n_spills,
  345. HIP_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
  346. hipSymbolTable.hipFuncGetAttribute(
  347. &n_max_threads, HIP_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun);
  348. n_spills /= 4;
  349. if (PyErr_Occurred()) {
  350. return NULL;
  351. }
  352. return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
  353. n_spills, n_max_threads);
  354. }
  355. static PyObject *createTDMDescriptor(PyObject *self, PyObject *args) {
  356. int elementBitWidth;
  357. PyObject *blockSize;
  358. int numWarps;
  359. int padInterval;
  360. int padAmount;
  361. PyObject *shape;
  362. PyObject *strides;
  363. unsigned long long globalAddress;
  364. if (!PyArg_ParseTuple(args, "iOiiiOOK", &elementBitWidth, &blockSize,
  365. &numWarps, &padInterval, &padAmount, &shape, &strides,
  366. &globalAddress)) {
  367. return NULL;
  368. }
  369. PyTDMDescriptorObject *descObj = (PyTDMDescriptorObject *)PyObject_CallObject(
  370. (PyObject *)&PyTDMDescriptorType, NULL);
  371. if (!descObj)
  372. return NULL;
  373. PyObject *blockSizeFast = NULL;
  374. PyObject *shapeFast = NULL;
  375. PyObject *stridesFast = NULL;
  376. uint32_t blockSizeInt[2];
  377. uint32_t shapeInt[2];
  378. uint32_t stridesInt[2];
  379. blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
  380. if (!blockSizeFast)
  381. goto cleanup;
  382. int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
  383. if (rank != 2) {
  384. PyErr_SetString(PyExc_RuntimeError, "rank must be 2");
  385. goto cleanup;
  386. }
  387. for (int i = 0; i < rank; ++i) {
  388. PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
  389. if (!PyLong_Check(item)) {
  390. PyErr_SetString(PyExc_TypeError, "block size must be an int");
  391. goto cleanup;
  392. }
  393. blockSizeInt[i] = PyLong_AsLong(item);
  394. }
  395. shapeFast = PySequence_Fast(shape, "shape must be a sequence");
  396. if (!shapeFast)
  397. goto cleanup;
  398. if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
  399. PyErr_SetString(PyExc_RuntimeError, "rank mismatch");
  400. goto cleanup;
  401. }
  402. for (int i = 0; i < rank; ++i) {
  403. PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
  404. if (!PyLong_Check(item)) {
  405. PyErr_SetString(PyExc_TypeError, "shape must be an int");
  406. goto cleanup;
  407. }
  408. shapeInt[i] = PyLong_AsLong(item);
  409. }
  410. stridesFast = PySequence_Fast(strides, "strides must be a sequence");
  411. if (!stridesFast)
  412. goto cleanup;
  413. if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
  414. PyErr_SetString(PyExc_RuntimeError, "rank mismatch");
  415. goto cleanup;
  416. }
  417. for (int i = 0; i < rank; ++i) {
  418. PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
  419. if (!PyLong_Check(item)) {
  420. PyErr_SetString(PyExc_TypeError, "shape must be an int");
  421. goto cleanup;
  422. }
  423. stridesInt[i] = PyLong_AsLong(item);
  424. }
  425. Py_DECREF(blockSizeFast);
  426. blockSizeFast = NULL;
  427. Py_DECREF(shapeFast);
  428. shapeFast = NULL;
  429. Py_DECREF(stridesFast);
  430. stridesFast = NULL;
  431. bool success = encodeTDMDescriptor(
  432. &descObj->desc, elementBitWidth, blockSizeInt, numWarps, padInterval,
  433. padAmount, shapeInt, stridesInt, globalAddress, rank);
  434. if (!success) {
  435. PyErr_SetString(PyExc_RuntimeError, "Failed to encode TDM descriptor");
  436. goto cleanup;
  437. }
  438. return (PyObject *)descObj;
  439. cleanup:
  440. Py_XDECREF(blockSizeFast);
  441. Py_XDECREF(shapeFast);
  442. Py_XDECREF(stridesFast);
  443. Py_XDECREF(descObj);
  444. return NULL;
  445. }
  446. static PyMethodDef ModuleMethods[] = {
  447. {"load_binary", loadBinary, METH_VARARGS,
  448. "Load provided hsaco into HIP driver"},
  449. {"get_device_properties", getDeviceProperties, METH_VARARGS,
  450. "Get the properties for a given device"},
  451. {"create_tdm_descriptor", createTDMDescriptor, METH_VARARGS,
  452. "create a host-side TDM descriptor"},
  453. {NULL, NULL, 0, NULL} // sentinel
  454. };
  455. static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "hip_utils",
  456. NULL, // documentation
  457. -1, // size
  458. ModuleMethods};
  459. PyMODINIT_FUNC PyInit_hip_utils(void) {
  460. if (!initSymbolTable()) {
  461. return NULL;
  462. }
  463. PyObject *m = PyModule_Create(&ModuleDef);
  464. if (m == NULL) {
  465. return NULL;
  466. }
  467. PyModule_AddFunctions(m, ModuleMethods);
  468. if (PyType_Ready(&PyTDMDescriptorType) < 0)
  469. return NULL;
  470. Py_INCREF(&PyTDMDescriptorType);
  471. PyModule_AddObject(m, "PyTDMDescriptor", (PyObject *)&PyTDMDescriptorType);
  472. return m;
  473. }