driver.c 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558
#define _CRT_SECURE_NO_WARNINGS
#include "cuda.h"
#ifndef _WIN32
#include <dlfcn.h>
#else
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#endif
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define PY_SSIZE_T_CLEAN
#include <Python.h>
  14. typedef struct {
  15. PyObject_HEAD _Alignas(128) CUtensorMap tensorMap;
  16. } PyCUtensorMapObject;
  17. // Raises a Python exception and returns false if code is not CUDA_SUCCESS.
  18. static bool gpuAssert(CUresult code, const char *file, int line) {
  19. if (code == CUDA_SUCCESS)
  20. return true;
  21. const char *prefix = "Triton Error [CUDA]: ";
  22. const char *str;
  23. cuGetErrorString(code, &str);
  24. char err[1024] = {0};
  25. strcat(err, prefix);
  26. strcat(err, str);
  27. PyGILState_STATE gil_state;
  28. gil_state = PyGILState_Ensure();
  29. PyErr_SetString(PyExc_RuntimeError, err);
  30. PyGILState_Release(gil_state);
  31. return false;
  32. }
  33. // To be used only *outside* a Py_{BEGIN,END}_ALLOW_THREADS block.
  34. #define CUDA_CHECK_AND_RETURN_NULL(ans) \
  35. do { \
  36. if (!gpuAssert((ans), __FILE__, __LINE__)) \
  37. goto cleanup; \
  38. } while (0)
  39. // To be used inside a Py_{BEGIN,END}_ALLOW_THREADS block.
  40. #define CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(ans) \
  41. do { \
  42. if (!gpuAssert((ans), __FILE__, __LINE__)) { \
  43. PyEval_RestoreThread(_save); \
  44. return NULL; \
  45. } \
  46. } while (0)
  47. // Used to check if functions exist in old CUDA driver versions.
  48. #define INITIALIZE_FUNCTION_POINTER_IF_NULL(funcPointer, initializerFunction) \
  49. do { \
  50. if ((funcPointer) == NULL) { \
  51. (funcPointer) = (initializerFunction)(); \
  52. if ((funcPointer) == NULL) { \
  53. goto cleanup; \
  54. } \
  55. } \
  56. } while (0)
  57. static PyObject *getDeviceProperties(PyObject *self, PyObject *args) {
  58. int device_id;
  59. if (!PyArg_ParseTuple(args, "i", &device_id))
  60. return NULL;
  61. // Get device handle
  62. CUdevice device;
  63. cuDeviceGet(&device, device_id);
  64. // create a struct to hold device properties
  65. int max_shared_mem;
  66. int max_num_regs;
  67. int multiprocessor_count;
  68. int warp_size;
  69. int sm_clock_rate;
  70. int mem_clock_rate;
  71. int mem_bus_width;
  72. CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
  73. &max_shared_mem, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
  74. device));
  75. CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
  76. &max_num_regs, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_BLOCK, device));
  77. CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
  78. &multiprocessor_count, CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device));
  79. CUDA_CHECK_AND_RETURN_NULL(
  80. cuDeviceGetAttribute(&warp_size, CU_DEVICE_ATTRIBUTE_WARP_SIZE, device));
  81. CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
  82. &sm_clock_rate, CU_DEVICE_ATTRIBUTE_CLOCK_RATE, device));
  83. CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
  84. &mem_clock_rate, CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE, device));
  85. CUDA_CHECK_AND_RETURN_NULL(cuDeviceGetAttribute(
  86. &mem_bus_width, CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH, device));
  87. return Py_BuildValue("{s:i, s:i, s:i, s:i, s:i, s:i, s:i}", "max_shared_mem",
  88. max_shared_mem, "max_num_regs", max_num_regs,
  89. "multiprocessor_count", multiprocessor_count, "warpSize",
  90. warp_size, "sm_clock_rate", sm_clock_rate,
  91. "mem_clock_rate", mem_clock_rate, "mem_bus_width",
  92. mem_bus_width);
  93. cleanup:
  94. return NULL;
  95. }
  96. static PyObject *loadBinary(PyObject *self, PyObject *args) {
  97. const char *name;
  98. const char *data;
  99. Py_ssize_t data_size;
  100. int shared;
  101. int device;
  102. if (!PyArg_ParseTuple(args, "ss#ii", &name, &data, &data_size, &shared,
  103. &device)) {
  104. return NULL;
  105. }
  106. CUfunction fun;
  107. CUmodule mod;
  108. int32_t n_regs = 0;
  109. int32_t n_spills = 0;
  110. int32_t n_max_threads = 0;
  111. // create driver handles
  112. CUcontext pctx = 0;
  113. Py_BEGIN_ALLOW_THREADS;
  114. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&pctx));
  115. if (!pctx) {
  116. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  117. cuDevicePrimaryCtxRetain(&pctx, device));
  118. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(pctx));
  119. }
  120. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuModuleLoadData(&mod, data));
  121. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  122. cuModuleGetFunction(&fun, mod, name));
  123. // get allocated registers and spilled registers from the function
  124. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  125. cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
  126. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  127. cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
  128. n_spills /= 4;
  129. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
  130. &n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
  131. // set dynamic shared memory if necessary
  132. int shared_optin;
  133. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
  134. &shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN,
  135. device));
  136. if (shared > 49152 && shared_optin > 49152) {
  137. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  138. cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED));
  139. int shared_total, shared_static;
  140. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuDeviceGetAttribute(
  141. &shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR,
  142. device));
  143. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
  144. &shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
  145. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  146. cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
  147. shared_optin - shared_static));
  148. }
  149. Py_END_ALLOW_THREADS;
  150. if (PyErr_Occurred()) {
  151. return NULL;
  152. }
  153. return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
  154. n_spills, n_max_threads);
  155. }
  156. typedef CUresult (*cuOccupancyMaxActiveClusters_t)(
  157. int *numClusters, CUfunction func, const CUlaunchConfig *config);
  158. typedef CUresult (*cuTensorMapEncodeTiled_t)(
  159. CUtensorMap *tensorMap, CUtensorMapDataType tensorDataType,
  160. cuuint32_t tensorRank, void *globalAddress, const cuuint64_t *globalDim,
  161. const cuuint64_t *globalStrides, const cuuint32_t *boxDim,
  162. const cuuint32_t *elementStrides, CUtensorMapInterleave interleave,
  163. CUtensorMapSwizzle swizzle, CUtensorMapL2promotion l2Promotion,
  164. CUtensorMapFloatOOBfill oobFill);
  165. #ifndef _WIN32
  166. #define defineGetFunctionHandle(name, symbolName) \
  167. static symbolName##_t name() { \
  168. /* Open the shared library */ \
  169. void *libHandle = dlopen("libcuda.so.1", RTLD_LAZY); \
  170. if (!libHandle) { \
  171. PyErr_SetString(PyExc_RuntimeError, "Failed to open libcuda.so.1"); \
  172. return NULL; \
  173. } \
  174. /* Clear any existing error */ \
  175. dlerror(); \
  176. symbolName##_t funcHandle = (symbolName##_t)dlsym(libHandle, #symbolName); \
  177. /* Check for errors */ \
  178. const char *err = dlerror(); \
  179. if (err) { \
  180. PyErr_SetString(PyExc_RuntimeError, \
  181. "Failed to retrieve " #symbolName " from libcuda.so.1"); \
  182. dlclose(libHandle); \
  183. return NULL; \
  184. } \
  185. return funcHandle; \
  186. }
  187. #else
  188. #define defineGetFunctionHandle(name, symbolName) \
  189. static symbolName##_t name() { \
  190. /* Open the shared library */ \
  191. HMODULE handle = LoadLibraryA("nvcuda.dll"); \
  192. if (!handle) { \
  193. PyErr_SetString(PyExc_RuntimeError, "Failed to open nvcuda.dll"); \
  194. return NULL; \
  195. } \
  196. symbolName##_t funcHandle = \
  197. (symbolName##_t)GetProcAddress((HMODULE)handle, #symbolName); \
  198. /* Check for errors */ \
  199. long err = GetLastError(); \
  200. if (err) { \
  201. PyErr_SetString(PyExc_RuntimeError, \
  202. "Failed to retrieve " #symbolName " from nvcuda.dll"); \
  203. return NULL; \
  204. } \
  205. return funcHandle; \
  206. }
  207. #endif
  208. defineGetFunctionHandle(getCuOccupancyMaxActiveClustersHandle,
  209. cuOccupancyMaxActiveClusters);
  210. defineGetFunctionHandle(getCuTensorMapEncodeTiledHandle,
  211. cuTensorMapEncodeTiled);
  212. static PyObject *occupancyMaxActiveClusters(PyObject *self, PyObject *args) {
  213. int clusterDim = -1, maxActiveClusters = -1;
  214. int shared = 0;
  215. CUfunction func;
  216. if (!PyArg_ParseTuple(args, "Kii", &func, &shared, &clusterDim)) {
  217. return NULL;
  218. }
  219. // Let each SM have one block
  220. int maxActiveBlocks = 1;
  221. Py_BEGIN_ALLOW_THREADS;
  222. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
  223. func, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared));
  224. Py_END_ALLOW_THREADS;
  225. CUlaunchAttribute launchAttr[1];
  226. launchAttr[0].id = CU_LAUNCH_ATTRIBUTE_CLUSTER_DIMENSION;
  227. launchAttr[0].value.clusterDim.x = clusterDim;
  228. launchAttr[0].value.clusterDim.y = 1;
  229. launchAttr[0].value.clusterDim.z = 1;
  230. CUlaunchConfig config;
  231. config.gridDimX = clusterDim * maxActiveBlocks;
  232. config.gridDimY = 1;
  233. config.gridDimZ = 1;
  234. config.blockDimX = 128;
  235. config.blockDimY = 1;
  236. config.blockDimZ = 1;
  237. config.sharedMemBytes = shared;
  238. config.hStream = 0;
  239. config.numAttrs = 1;
  240. config.attrs = launchAttr;
  241. static cuOccupancyMaxActiveClusters_t cuOccupancyMaxActiveClusters = NULL;
  242. INITIALIZE_FUNCTION_POINTER_IF_NULL(cuOccupancyMaxActiveClusters,
  243. getCuOccupancyMaxActiveClustersHandle);
  244. Py_BEGIN_ALLOW_THREADS;
  245. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncSetAttribute(
  246. func, CU_FUNC_ATTRIBUTE_NON_PORTABLE_CLUSTER_SIZE_ALLOWED, 1));
  247. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  248. cuOccupancyMaxActiveClusters(&maxActiveClusters, func, &config));
  249. Py_END_ALLOW_THREADS;
  250. return PyLong_FromLong(maxActiveClusters);
  251. cleanup:
  252. return NULL;
  253. }
  254. static PyObject *setPrintfFifoSize(PyObject *self, PyObject *args) {
  255. long size;
  256. if (!PyArg_ParseTuple(args, "l", &size)) {
  257. return NULL;
  258. }
  259. if (size < 0) {
  260. PyErr_SetString(PyExc_ValueError, "fifo size must be non-negative");
  261. return NULL;
  262. }
  263. Py_BEGIN_ALLOW_THREADS;
  264. // Ensure we have an active context.
  265. CUcontext ctx = NULL;
  266. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxGetCurrent(&ctx));
  267. if (!ctx) {
  268. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  269. cuDevicePrimaryCtxRetain(&ctx, /*device=*/0));
  270. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuCtxSetCurrent(ctx));
  271. }
  272. // We can't set the fifo size after running a kernel that calls printf. This
  273. // is true even if the set() call is a nop and the new size is the same as the
  274. // old size.
  275. //
  276. // This is unfriendly, so check if the old size matches the new size, and skip
  277. // the set() call if so.
  278. size_t oldSize = 0;
  279. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  280. cuCtxGetLimit(&oldSize, CU_LIMIT_PRINTF_FIFO_SIZE));
  281. if (oldSize != size) {
  282. CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
  283. cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, size));
  284. }
  285. Py_END_ALLOW_THREADS;
  286. Py_RETURN_NONE;
  287. }
  288. static PyObject *PyCUtensorMap_alloc(PyTypeObject *type, Py_ssize_t n_items) {
  289. PyCUtensorMapObject *self = NULL;
  290. void *mem = NULL;
  291. size_t size = type->tp_basicsize;
  292. #ifdef _WIN32
  293. mem = _aligned_malloc(size, 128);
  294. if (mem == NULL) {
  295. #else
  296. if (posix_memalign(&mem, 128, size) != 0) {
  297. #endif
  298. PyErr_NoMemory();
  299. return NULL;
  300. }
  301. self = (PyCUtensorMapObject *)mem;
  302. PyObject_INIT(self, type);
  303. return (PyObject *)self;
  304. }
  305. static void PyCUtensorMap_dealloc(PyObject *self) {
  306. Py_TYPE(self)->tp_free(self);
  307. }
  308. static void PyCUtensorMap_free(void *ptr) {
  309. #ifdef _WIN32
  310. _aligned_free(ptr);
  311. #else
  312. free(ptr);
  313. #endif
  314. }
  315. // clang-format off
  316. static PyTypeObject PyCUtensorMapType = {
  317. PyVarObject_HEAD_INIT(NULL, 0)
  318. .tp_name = "triton.backends.nvidia.PyCUtensorMap",
  319. .tp_basicsize = sizeof(PyCUtensorMapObject),
  320. .tp_itemsize = 0,
  321. .tp_flags = Py_TPFLAGS_DEFAULT,
  322. .tp_doc = "<PyCUtensorMap object>",
  323. .tp_new = PyType_GenericNew,
  324. .tp_alloc = PyCUtensorMap_alloc,
  325. .tp_dealloc = (destructor)PyCUtensorMap_dealloc,
  326. .tp_free = PyCUtensorMap_free,
  327. };
  328. // clang-format on
  329. static PyObject *fillTMADescriptor(PyObject *self, PyObject *args) {
  330. unsigned long long global_address;
  331. int swizzle;
  332. int elemSize;
  333. int elemType;
  334. PyObject *blockSize;
  335. PyObject *shape;
  336. PyObject *strides;
  337. int padding;
  338. if (!PyArg_ParseTuple(args, "KiiiOOOi", &global_address, &swizzle, &elemSize,
  339. &elemType, &blockSize, &shape, &strides, &padding)) {
  340. return NULL;
  341. }
  342. PyCUtensorMapObject *desc = (PyCUtensorMapObject *)PyObject_CallObject(
  343. (PyObject *)&PyCUtensorMapType, NULL);
  344. if (!desc) {
  345. return NULL;
  346. }
  347. PyObject *blockSizeFast = NULL;
  348. PyObject *shapeFast = NULL;
  349. PyObject *stridesFast = NULL;
  350. uint32_t blockSizeInt[5];
  351. uint64_t shapeInt[5];
  352. uint64_t stridesLL[5];
  353. blockSizeFast = PySequence_Fast(blockSize, "blockSize must be a sequence");
  354. if (!blockSizeFast)
  355. goto cleanup;
  356. int rank = PySequence_Fast_GET_SIZE(blockSizeFast);
  357. for (int i = 0; i < rank; ++i) {
  358. PyObject *item = PySequence_Fast_GET_ITEM(blockSizeFast, i);
  359. if (!PyLong_Check(item)) {
  360. PyErr_SetString(PyExc_TypeError, "block size must be an int");
  361. goto cleanup;
  362. }
  363. blockSizeInt[rank - i - 1] = PyLong_AsLongLong(item);
  364. }
  365. shapeFast = PySequence_Fast(shape, "shape must be a sequence");
  366. if (!shapeFast)
  367. goto cleanup;
  368. if (rank != PySequence_Fast_GET_SIZE(shapeFast)) {
  369. PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
  370. goto cleanup;
  371. }
  372. for (int i = 0; i < rank; ++i) {
  373. PyObject *item = PySequence_Fast_GET_ITEM(shapeFast, i);
  374. if (!PyLong_Check(item)) {
  375. PyErr_SetString(PyExc_TypeError, "shape must be an int");
  376. goto cleanup;
  377. }
  378. shapeInt[rank - i - 1] = PyLong_AsLong(item);
  379. }
  380. stridesFast = PySequence_Fast(strides, "strides must be a sequence");
  381. if (!stridesFast)
  382. goto cleanup;
  383. if (rank != PySequence_Fast_GET_SIZE(stridesFast)) {
  384. PyErr_SetString(PyExc_RuntimeError, "Rank mismatch");
  385. goto cleanup;
  386. }
  387. for (int i = 0; i + 1 < rank; ++i) {
  388. PyObject *item = PySequence_Fast_GET_ITEM(stridesFast, i);
  389. if (!PyLong_Check(item)) {
  390. PyErr_SetString(PyExc_TypeError, "shape must be an int");
  391. goto cleanup;
  392. }
  393. stridesLL[rank - i - 2] = elemSize * PyLong_AsLongLong(item);
  394. }
  395. stridesLL[rank - 1] =
  396. shapeInt[rank - 1] * (rank == 1 ? elemSize : stridesLL[rank - 2]);
  397. Py_DECREF(blockSizeFast);
  398. blockSizeFast = NULL;
  399. Py_DECREF(shapeFast);
  400. shapeFast = NULL;
  401. Py_DECREF(stridesFast);
  402. stridesFast = NULL;
  403. CUtensorMapFloatOOBfill fill =
  404. (padding == 1) ? CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA
  405. : CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
  406. uint32_t elementStrides[5] = {1, 1, 1, 1, 1};
  407. static cuTensorMapEncodeTiled_t cuTensorMapEncodeTiled = NULL;
  408. INITIALIZE_FUNCTION_POINTER_IF_NULL(cuTensorMapEncodeTiled,
  409. getCuTensorMapEncodeTiledHandle);
  410. CUresult res = cuTensorMapEncodeTiled(
  411. &desc->tensorMap, elemType, rank, (void *)global_address, shapeInt,
  412. stridesLL, blockSizeInt, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
  413. swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B, fill);
  414. if (res != CUDA_SUCCESS) {
  415. const char *str;
  416. cuGetErrorString(res, &str);
  417. char err[4096] = {0};
  418. size_t off = 0;
  419. off += snprintf(
  420. err + off, sizeof(err) - off,
  421. "Triton Error [CUDA]: Failed to create tensor map descriptor: %s\n",
  422. str ? str : "Unknown error");
  423. off += snprintf(err + off, sizeof(err) - off,
  424. "elemType=%d rank=%d global_address=0x%llx elemSize=%d "
  425. "swizzle=%d padding=%d\n",
  426. elemType, rank, (unsigned long long)global_address,
  427. elemSize, swizzle, padding);
  428. off += snprintf(err + off, sizeof(err) - off, "shape=[");
  429. for (int i = 0; i < rank; ++i) {
  430. off +=
  431. snprintf(err + off, sizeof(err) - off, "%llu%s",
  432. (unsigned long long)shapeInt[i], (i + 1 < rank) ? ", " : "");
  433. }
  434. off += snprintf(err + off, sizeof(err) - off, "]\n");
  435. off += snprintf(err + off, sizeof(err) - off, "strides=[");
  436. for (int i = 0; i < rank; ++i) {
  437. off += snprintf(err + off, sizeof(err) - off, "%llu%s",
  438. (unsigned long long)stridesLL[i],
  439. (i + 1 < rank) ? ", " : "");
  440. }
  441. off += snprintf(err + off, sizeof(err) - off, "]\n");
  442. off += snprintf(err + off, sizeof(err) - off, "blockSize=[");
  443. for (int i = 0; i < rank; ++i) {
  444. off += snprintf(err + off, sizeof(err) - off, "%u%s",
  445. (unsigned)blockSizeInt[i], (i + 1 < rank) ? ", " : "");
  446. }
  447. off += snprintf(err + off, sizeof(err) - off, "] elementStrides=[");
  448. for (int i = 0; i < rank; ++i) {
  449. off += snprintf(err + off, sizeof(err) - off, "%u%s",
  450. (unsigned)elementStrides[i], (i + 1 < rank) ? ", " : "");
  451. }
  452. off += snprintf(err + off, sizeof(err) - off, "]\n");
  453. PyErr_SetString(PyExc_RuntimeError, err);
  454. goto cleanup;
  455. }
  456. return (PyObject *)desc;
  457. cleanup:
  458. Py_XDECREF(blockSizeFast);
  459. Py_XDECREF(shapeFast);
  460. Py_XDECREF(stridesFast);
  461. Py_XDECREF(desc);
  462. return NULL;
  463. }
  464. static PyMethodDef ModuleMethods[] = {
  465. {"load_binary", loadBinary, METH_VARARGS,
  466. "Load provided cubin into CUDA driver"},
  467. {"get_device_properties", getDeviceProperties, METH_VARARGS,
  468. "Get the properties for a given device"},
  469. {"cuOccupancyMaxActiveClusters", occupancyMaxActiveClusters, METH_VARARGS,
  470. "Python interface for cuOccupancyMaxActiveClusters function"},
  471. {"set_printf_fifo_size", setPrintfFifoSize, METH_VARARGS,
  472. "Python interface for cuCtxSetLimit(CU_LIMIT_PRINTF_FIFO_SIZE, x), which "
  473. "controls how many bytes can be streamed from kernels before data starts "
  474. "being dropped. This inherits all the limitations of this call; in "
  475. "particular it's an error to change this value after launching any kernel "
  476. "that calls printf()."},
  477. {"fill_tma_descriptor", fillTMADescriptor, METH_VARARGS, "doc"},
  478. {NULL, NULL, 0, NULL} // sentinel
  479. };
  480. static struct PyModuleDef ModuleDef = {PyModuleDef_HEAD_INIT, "cuda_utils",
  481. NULL, // documentation
  482. -1, // size
  483. ModuleMethods};
  484. PyMODINIT_FUNC PyInit_cuda_utils(void) {
  485. if (PyType_Ready(&PyCUtensorMapType) < 0) {
  486. return NULL;
  487. }
  488. PyObject *m = PyModule_Create(&ModuleDef);
  489. if (m == NULL) {
  490. return NULL;
  491. }
  492. PyModule_AddFunctions(m, ModuleMethods);
  493. Py_INCREF(&PyCUtensorMapType);
  494. PyModule_AddObject(m, "PyCUtensorMap", (PyObject *)&PyCUtensorMapType);
  495. return m;
  496. }