_utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. import ctypes
  2. import sys
  3. from typing import Any
  4. import torch
# _get_device_index has been moved to torch._utils._get_device_index
  6. from torch._utils import _get_device_index as _torch_get_device_index
  7. def _get_hip_runtime_library() -> ctypes.CDLL:
  8. # If ROCm python packages are available, query the OS-independent absolute
  9. # path to the library provided by those packages, including any version suffix.
  10. # See https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md#dynamic-library-resolution
  11. try:
  12. # pyrefly: ignore [import-error, missing-import]
  13. import rocm_sdk
  14. lib = ctypes.CDLL(str(rocm_sdk.find_libraries("amdhip64")[0]))
  15. except (ImportError, IndexError):
  16. if sys.platform == "win32":
  17. lib = ctypes.CDLL(f"amdhip64_{torch.version.hip[0]}.dll")
  18. else: # Unix-based systems
  19. lib = ctypes.CDLL("libamdhip64.so")
  20. lib.cuGetErrorString = lib.hipGetErrorString # type: ignore[attr-defined]
  21. lib.cuModuleLoadData = lib.hipModuleLoadData # type: ignore[attr-defined]
  22. lib.cuModuleGetFunction = lib.hipModuleGetFunction # type: ignore[attr-defined]
  23. lib.cuLaunchKernel = lib.hipModuleLaunchKernel # type: ignore[attr-defined]
  24. lib.cuFuncSetAttribute = lib.hipFuncSetAttribute # type: ignore[attr-defined]
  25. return lib
  26. def _get_cuda_library() -> ctypes.CDLL:
  27. if sys.platform == "win32":
  28. return ctypes.CDLL("nvcuda.dll")
  29. else: # Unix-based systems
  30. return ctypes.CDLL("libcuda.so.1")
  31. # Load GPU driver runtime
  32. def _get_gpu_runtime_library() -> ctypes.CDLL:
  33. if torch.version.hip:
  34. return _get_hip_runtime_library()
  35. else:
  36. return _get_cuda_library()
  37. # Helper: check CUDA errors
  38. def _check_cuda(result: int) -> None:
  39. if result == 0:
  40. return
  41. err_str = ctypes.c_char_p()
  42. libcuda = _get_gpu_runtime_library() # Get reference to CUDA library
  43. libcuda.cuGetErrorString(result, ctypes.byref(err_str))
  44. error_message = (
  45. err_str.value.decode() if err_str.value is not None else "Unknown CUDA error"
  46. )
  47. raise RuntimeError(f"CUDA error: {error_message}")
  48. def _get_hiprtc_library() -> ctypes.CDLL:
  49. try:
  50. # pyrefly: ignore [import-error, missing-import]
  51. import rocm_sdk
  52. lib = ctypes.CDLL(str(rocm_sdk.find_libraries("hiprtc")[0]))
  53. except (ImportError, IndexError):
  54. if sys.platform == "win32":
  55. version_str = "".join(
  56. ["0", torch.version.hip[0], "0", torch.version.hip[2]]
  57. )
  58. lib = ctypes.CDLL(f"hiprtc{version_str}.dll")
  59. else:
  60. lib = ctypes.CDLL("libhiprtc.so")
  61. # Provide aliases for HIP RTC functions to match NVRTC API
  62. lib.nvrtcGetErrorString = lib.hiprtcGetErrorString # type: ignore[attr-defined]
  63. lib.nvrtcCreateProgram = lib.hiprtcCreateProgram # type: ignore[attr-defined]
  64. lib.nvrtcDestroyProgram = lib.hiprtcDestroyProgram # type: ignore[attr-defined]
  65. lib.nvrtcCompileProgram = lib.hiprtcCompileProgram # type: ignore[attr-defined]
  66. lib.nvrtcGetCUBINSize = lib.hiprtcGetCodeSize # type: ignore[attr-defined]
  67. lib.nvrtcGetCUBIN = lib.hiprtcGetCode # type: ignore[attr-defined]
  68. lib.nvrtcGetProgramLogSize = lib.hiprtcGetProgramLogSize # type: ignore[attr-defined]
  69. lib.nvrtcGetProgramLog = lib.hiprtcGetProgramLog # type: ignore[attr-defined]
  70. lib.nvrtcAddNameExpression = lib.hiprtcAddNameExpression # type: ignore[attr-defined]
  71. lib.nvrtcGetLoweredName = lib.hiprtcGetLoweredName # type: ignore[attr-defined]
  72. return lib
  73. def _get_nvrtc_library() -> ctypes.CDLL:
  74. major_version = int(torch.version.cuda.split(".")[0]) # type: ignore[union-attr]
  75. if sys.platform == "win32":
  76. nvrtc_libs = [
  77. f"nvrtc64_{major_version}0_0.dll",
  78. ]
  79. else:
  80. nvrtc_libs = [
  81. f"libnvrtc.so.{major_version}",
  82. "libnvrtc.so", # Fallback to unversioned
  83. ]
  84. for lib_name in nvrtc_libs:
  85. try:
  86. return ctypes.CDLL(lib_name)
  87. except OSError:
  88. continue
  89. raise OSError("Could not find any NVRTC library")
  90. def _get_gpu_rtc_library() -> ctypes.CDLL:
  91. # Since PyTorch already loads the GPU RTC library, we can use the system library
  92. # which should be compatible with PyTorch's version
  93. if torch.version.hip:
  94. return _get_hiprtc_library()
  95. else:
  96. return _get_nvrtc_library()
  97. def _get_gpu_rtc_compatible_flags() -> list[str]:
  98. """
  99. Get HIPCC/NVCC flags that are compatible with NVRTC compilation.
  100. Returns:
  101. List of HIPCC/NVCC flags that can be safely used with NVRTC.
  102. """
  103. from torch.utils.cpp_extension import COMMON_HIPCC_FLAGS, COMMON_NVCC_FLAGS
  104. nvrtc_unsupported_flags = {
  105. "--expt-relaxed-constexpr",
  106. }
  107. # Filter out unsupported flags
  108. compatible_flags = [
  109. flag for flag in COMMON_NVCC_FLAGS if flag not in nvrtc_unsupported_flags
  110. ]
  111. if torch.version.hip:
  112. compatible_flags.extend(COMMON_HIPCC_FLAGS)
  113. return compatible_flags
  114. def _nvrtc_compile(
  115. kernel_source: str,
  116. kernel_name: str,
  117. compute_capability: str | None = None,
  118. cuda_include_dirs: list | None = None,
  119. nvcc_options: list | None = None,
  120. auto_pch: bool = False,
  121. ) -> tuple[bytes, str]:
  122. """
  123. Compiles a CUDA kernel using NVRTC and returns the PTX code.
  124. Args:
  125. kernel_source (str): The CUDA kernel source code as a string
  126. kernel_name (str): The name of the kernel function to compile
  127. compute_capability (str, None): The compute capability to target (e.g., "86").
  128. If None, will detect from current device.
  129. cuda_include_dirs (list, None): List of directories containing CUDA headers
  130. nvcc_options (list, None): Additional options to pass to NVRTC
  131. auto_pch (bool): Enable automatic precompiled headers (CUDA 12.8+)
  132. Returns:
  133. Tuple[bytes, str]: The compiled PTX code and mangled kernel name
  134. """
  135. # Ensure CUDA is initialized
  136. import torch.cuda
  137. # Load NVRTC library
  138. libnvrtc = _get_gpu_rtc_library()
  139. # NVRTC constants
  140. NVRTC_SUCCESS = 0
  141. # Helper: check NVRTC errors
  142. def check_nvrtc(result: int) -> None:
  143. if result != NVRTC_SUCCESS:
  144. err_str = ctypes.c_char_p()
  145. libnvrtc.nvrtcGetErrorString(result, ctypes.byref(err_str))
  146. error_message = (
  147. err_str.value.decode()
  148. if err_str.value is not None
  149. else "Unknown CUDA error"
  150. )
  151. raise RuntimeError(f"CUDA error: {error_message}")
  152. # Convert source to bytes
  153. source_bytes = kernel_source.encode("utf-8")
  154. # Get compute capability if not provided
  155. if compute_capability is None:
  156. props = torch.cuda.get_device_properties(torch.cuda.current_device())
  157. if torch.version.hip:
  158. compute_capability = f"{props.gcnArchName}"
  159. else:
  160. compute_capability = f"{props.major}{props.minor}"
  161. # Prepare compilation options
  162. options = []
  163. if torch.version.hip:
  164. options.append(f"--offload-arch={compute_capability}".encode())
  165. else:
  166. options.append(f"--gpu-architecture=sm_{compute_capability}".encode())
  167. # Auto-detect and add CUDA include paths
  168. from torch.utils.cpp_extension import include_paths
  169. cuda_include_paths = include_paths("cuda")
  170. for cuda_path in cuda_include_paths:
  171. options.append(f"-I{cuda_path}".encode())
  172. # Add custom include directories
  173. if cuda_include_dirs:
  174. for directory in cuda_include_dirs:
  175. options.append(f"-I{directory}".encode())
  176. # Enable automatic precompiled headers (CUDA 12.8+)
  177. if auto_pch:
  178. if str(torch.version.cuda) < "12.8":
  179. raise AssertionError(f"PCH requires CUDA 12.8+, got {torch.version.cuda}")
  180. if nvcc_options is None:
  181. nvcc_options = []
  182. nvcc_options.append("--pch")
  183. # Add custom NVCC options
  184. if nvcc_options:
  185. for option in nvcc_options:
  186. options.append(option.encode("utf-8"))
  187. nvrtc_compatible_flags = _get_gpu_rtc_compatible_flags()
  188. options.extend([flag.encode("utf-8") for flag in nvrtc_compatible_flags])
  189. # Convert options to C array
  190. num_options = len(options)
  191. options_array = (ctypes.c_char_p * num_options)(*options)
  192. # Create program
  193. prog = ctypes.c_void_p()
  194. check_nvrtc(
  195. libnvrtc.nvrtcCreateProgram(
  196. ctypes.byref(prog),
  197. source_bytes,
  198. f"{kernel_name}.cu".encode(),
  199. 0,
  200. None,
  201. None,
  202. )
  203. )
  204. # Add kernel name, which can be a template expression
  205. c_kernel_name = kernel_name.encode("utf-8")
  206. check_nvrtc(libnvrtc.nvrtcAddNameExpression(prog, c_kernel_name))
  207. # Compile program
  208. res = libnvrtc.nvrtcCompileProgram(prog, num_options, options_array)
  209. # Handle compilation errors
  210. if res != NVRTC_SUCCESS:
  211. # Get log
  212. log_size = ctypes.c_size_t()
  213. libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size))
  214. log = ctypes.create_string_buffer(log_size.value)
  215. libnvrtc.nvrtcGetProgramLog(prog, log)
  216. raise RuntimeError(f"Kernel compilation failed:\n{log.value.decode()}")
  217. # Get binary
  218. binary_size = ctypes.c_size_t()
  219. check_nvrtc(libnvrtc.nvrtcGetCUBINSize(prog, ctypes.byref(binary_size)))
  220. binary = ctypes.create_string_buffer(binary_size.value)
  221. check_nvrtc(libnvrtc.nvrtcGetCUBIN(prog, binary))
  222. # Get mangled name
  223. c_mangled_name = ctypes.c_char_p()
  224. check_nvrtc(
  225. libnvrtc.nvrtcGetLoweredName(prog, c_kernel_name, ctypes.byref(c_mangled_name))
  226. )
  227. if c_mangled_name.value is not None:
  228. mangled_name = c_mangled_name.value.decode() # make a copy
  229. else:
  230. mangled_name = ""
  231. libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog))
  232. # For some reason, ".value" causes the string to be truncated,
  233. # likely due to the presence of '\0' in the string. So we use .raw instead.
  234. return binary.raw, mangled_name
  235. class _CudaModule:
  236. def __init__(self, module: ctypes.c_void_p) -> None:
  237. self._module = module
  238. self._kernels: dict[str, _CudaKernel] = {}
  239. def __getattr__(self, name: str) -> "_CudaKernel":
  240. if name in self._kernels:
  241. return self._kernels[name]
  242. # Import the CUDA library inside the method
  243. # pyrefly: ignore [missing-module-attribute]
  244. from torch.cuda._utils import _get_gpu_runtime_library
  245. libcuda = _get_gpu_runtime_library()
  246. func = ctypes.c_void_p()
  247. try:
  248. _check_cuda(
  249. libcuda.cuModuleGetFunction(
  250. ctypes.byref(func), self._module, name.encode("utf-8")
  251. )
  252. )
  253. kernel = _CudaKernel(func, self._module)
  254. self._kernels[name] = kernel
  255. return kernel
  256. except RuntimeError as err:
  257. raise AttributeError(f"No kernel named '{name}' in this module") from err
  258. class _CudaKernel:
  259. """
  260. Represents a compiled CUDA kernel that can be called with PyTorch tensors.
  261. """
  262. def __init__(self, func: ctypes.c_void_p, module: ctypes.c_void_p) -> None:
  263. self.func = func
  264. self.module = module
  265. self._max_shared_mem_bytes = 0
  266. def __call__(
  267. self,
  268. grid: tuple[int, int, int] = (1, 1, 1),
  269. block: tuple[int, int, int] = (1, 1, 1),
  270. args: list | None = None,
  271. shared_mem: int = 0,
  272. stream: Any | None = None,
  273. ) -> None:
  274. """
  275. Call the compiled CUDA kernel
  276. Args:
  277. grid (tuple): Grid dimensions (grid_x, grid_y, grid_z)
  278. block (tuple): Block dimensions (block_x, block_y, block_z)
  279. args (list): List of arguments to pass to the kernel.
  280. PyTorch tensor arguments will be automatically converted to pointers.
  281. shared_mem (int): Shared memory size in bytes
  282. stream (torch.cuda.Stream): CUDA stream to use. If None, uses current stream.
  283. """
  284. import torch
  285. libcuda = torch.cuda._utils._get_gpu_runtime_library()
  286. if not args:
  287. args = []
  288. # Process arguments and convert tensors to pointers
  289. processed_args: list[ctypes.c_void_p] = []
  290. c_args = []
  291. for arg in args:
  292. if isinstance(arg, torch.Tensor):
  293. if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()):
  294. raise ValueError(
  295. "All tensor arguments must be CUDA tensors or pinned CPU tensors"
  296. )
  297. # Get pointer to tensor data
  298. ptr = ctypes.c_void_p(arg.data_ptr())
  299. processed_args.append(ptr)
  300. c_args.append(ctypes.byref(ptr))
  301. elif isinstance(arg, int):
  302. # Convert integers to C int
  303. c_int = ctypes.c_int(arg)
  304. # Store the C int for reference keeping, not in processed_args
  305. c_args.append(ctypes.byref(c_int))
  306. elif isinstance(arg, float):
  307. # Python floats are doubles - use double by default
  308. c_double = ctypes.c_double(arg)
  309. # Store the C double for reference keeping, not in processed_args
  310. c_args.append(ctypes.byref(c_double))
  311. else:
  312. raise TypeError(f"Unsupported argument type: {type(arg)}")
  313. # Convert to array of void pointers
  314. c_args_array = (ctypes.c_void_p * len(c_args))()
  315. for i, arg in enumerate(c_args):
  316. c_args_array[i] = ctypes.cast(arg, ctypes.c_void_p)
  317. # Get the stream
  318. if stream is None:
  319. # Defer import to avoid circular imports
  320. import torch.cuda
  321. stream = torch.cuda.current_stream()
  322. # Check if kernel requires large shared memory but hasn't been configured
  323. if shared_mem >= 48 * 1024 and (
  324. self._max_shared_mem_bytes == 0 or shared_mem > self._max_shared_mem_bytes
  325. ):
  326. configured_msg = (
  327. "not configured"
  328. if self._max_shared_mem_bytes == 0
  329. else f"only {self._max_shared_mem_bytes} bytes configured"
  330. )
  331. raise RuntimeError(
  332. f"Kernel requires {shared_mem} bytes of shared memory (>= 48KB), "
  333. f"but {configured_msg}. "
  334. "Call kernel.set_shared_memory_config(shared_mem) after compilation "
  335. "and before launching the kernel."
  336. )
  337. _check_cuda(
  338. libcuda.cuLaunchKernel(
  339. self.func,
  340. grid[0],
  341. grid[1],
  342. grid[2],
  343. block[0],
  344. block[1],
  345. block[2],
  346. shared_mem,
  347. stream._as_parameter_,
  348. c_args_array,
  349. None,
  350. )
  351. )
  352. def set_shared_memory_config(self, shared_mem_bytes: int) -> None:
  353. if shared_mem_bytes < 48 * 1024:
  354. # No configuration needed for <= 48KB, just update the value
  355. self._max_shared_mem_bytes = shared_mem_bytes
  356. return
  357. libcuda = _get_gpu_runtime_library()
  358. # Get device properties to validate against limits
  359. device_props = torch.cuda.get_device_properties()
  360. # HIP doesn't have shared_memory_per_block_optin in device properties, so we hard-code it here
  361. if torch.version.hip:
  362. # navi, CDNA1-CDNA3 allows a max of 64KB shared memory
  363. # CDNA4 allows a max of 160KB shared memory
  364. max_shared_mem = (
  365. 65536 if device_props.gcnArchName != "gfx950" else 160 * 1024
  366. )
  367. else:
  368. max_shared_mem = getattr(
  369. device_props, "shared_memory_per_block_optin", 49152
  370. )
  371. if shared_mem_bytes > max_shared_mem:
  372. raise RuntimeError(
  373. f"Requested shared memory ({shared_mem_bytes} bytes) exceeds "
  374. f"device limit ({max_shared_mem} bytes). "
  375. "Consider reducing block size or shared memory usage."
  376. )
  377. # Set the function attribute once
  378. # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
  379. cudaFuncAttributeMaxDynamicSharedMemorySize = 8
  380. _check_cuda(
  381. libcuda.cuFuncSetAttribute(
  382. self.func,
  383. cudaFuncAttributeMaxDynamicSharedMemorySize,
  384. shared_mem_bytes,
  385. )
  386. )
  387. self._max_shared_mem_bytes = shared_mem_bytes
  388. def _cuda_load_module(
  389. ptx: str | bytes, kernel_names: list[str] | None = None
  390. ) -> _CudaModule | dict[str, "_CudaKernel"]:
  391. """
  392. Loads a CUDA module from PTX code and returns a module object that can access kernels.
  393. Args:
  394. ptx (bytes or str): The PTX code to load
  395. kernel_names (list, optional): List of kernel names to extract from the module.
  396. If None, will return a module object with __getattr__.
  397. Returns:
  398. object: If kernel_names is None, returns a module object with __getattr__ to access kernels.
  399. If kernel_names is provided, returns a dict mapping kernel names to _CudaKernel objects.
  400. """
  401. # Ensure CUDA is initialized
  402. import torch.cuda
  403. # Load CUDA driver library
  404. libcuda = _get_gpu_runtime_library()
  405. # Convert PTX to bytes if it's a string
  406. if isinstance(ptx, str):
  407. ptx = ptx.encode("utf-8")
  408. # Load PTX module
  409. module = ctypes.c_void_p()
  410. # Get the current stream without directly importing torch.cuda at module level
  411. stream = torch.cuda.current_stream()
  412. with stream:
  413. _check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), ptx))
  414. if not kernel_names:
  415. return _CudaModule(module)
  416. # Return specific kernels
  417. kernels = {}
  418. for name in kernel_names:
  419. func = ctypes.c_void_p()
  420. _check_cuda(
  421. libcuda.cuModuleGetFunction(
  422. ctypes.byref(func), module, name.encode("utf-8")
  423. )
  424. )
  425. kernels[name] = _CudaKernel(func, module)
  426. return kernels
  427. def _get_device_index(
  428. device: Any, optional: bool = False, allow_cpu: bool = False
  429. ) -> int:
  430. r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.
  431. If :attr:`device` is a torch.device object, returns the device index if it
  432. is a CUDA device. Note that for a CUDA device without a specified index,
  433. i.e., ``torch.device('cuda')``, this will return the current default CUDA
  434. device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
  435. CPU devices will be accepted and ``-1`` will be returned in this case.
  436. If :attr:`device` is a Python integer, it is returned as is.
  437. If :attr:`device` is ``None``, this will return the current default CUDA
  438. device if :attr:`optional` is ``True``.
  439. """
  440. if isinstance(device, int):
  441. return device
  442. if isinstance(device, str):
  443. device = torch.device(device)
  444. if isinstance(device, torch.device):
  445. if allow_cpu:
  446. if device.type not in ["cuda", "cpu"]:
  447. raise ValueError(f"Expected a cuda or cpu device, but got: {device}")
  448. elif device.type != "cuda":
  449. raise ValueError(f"Expected a cuda device, but got: {device}")
  450. if not torch.jit.is_scripting():
  451. if isinstance(device, torch.cuda.device):
  452. return device.idx
  453. return _torch_get_device_index(device, optional, allow_cpu)