| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546 |
- import ctypes
- import sys
- from typing import Any
- import torch
- # The generic _get_device_index lives in torch._utils; it is imported here under an alias
- from torch._utils import _get_device_index as _torch_get_device_index
def _get_hip_runtime_library() -> ctypes.CDLL:
    """
    Load the HIP runtime library and expose its entry points under CUDA driver names.

    Returns:
        ctypes.CDLL: The loaded ``amdhip64`` library with ``cu*`` attributes
        aliased to the corresponding ``hip*`` functions.
    """
    # If ROCm python packages are available, query the OS-independent absolute
    # path to the library provided by those packages, including any version suffix.
    # See https://github.com/ROCm/TheRock/blob/main/docs/packaging/python_packaging.md#dynamic-library-resolution
    try:
        # pyrefly: ignore [import-error, missing-import]
        import rocm_sdk

        lib = ctypes.CDLL(str(rocm_sdk.find_libraries("amdhip64")[0]))
    except (ImportError, IndexError):
        if sys.platform == "win32":
            # Use the whole major component rather than only the first character
            # so that a multi-digit major version still yields the right name.
            hip_major = torch.version.hip.split(".")[0]  # type: ignore[union-attr]
            lib = ctypes.CDLL(f"amdhip64_{hip_major}.dll")
        else:  # Unix-based systems
            lib = ctypes.CDLL("libamdhip64.so")

    # cuGetErrorString(err, const char**) writes the message through an
    # out-pointer. The HIP function with that shape is hipDrvGetErrorString;
    # hipGetErrorString instead *returns* the string and ignores a second
    # argument, which would leave the caller's buffer unset. Fall back to it
    # only when the driver-style symbol is not exported by this library.
    try:
        get_error_string = lib.hipDrvGetErrorString
    except AttributeError:
        get_error_string = lib.hipGetErrorString
    lib.cuGetErrorString = get_error_string  # type: ignore[attr-defined]
    lib.cuModuleLoadData = lib.hipModuleLoadData  # type: ignore[attr-defined]
    lib.cuModuleGetFunction = lib.hipModuleGetFunction  # type: ignore[attr-defined]
    lib.cuLaunchKernel = lib.hipModuleLaunchKernel  # type: ignore[attr-defined]
    lib.cuFuncSetAttribute = lib.hipFuncSetAttribute  # type: ignore[attr-defined]
    return lib
def _get_cuda_library() -> ctypes.CDLL:
    """Load the CUDA driver library under its platform-specific file name."""
    # Windows ships the driver as nvcuda.dll; everything else is treated as Unix.
    driver_name = "nvcuda.dll" if sys.platform == "win32" else "libcuda.so.1"
    return ctypes.CDLL(driver_name)
# Load GPU driver runtime
def _get_gpu_runtime_library() -> ctypes.CDLL:
    """Return the HIP runtime on ROCm builds of torch, otherwise the CUDA driver."""
    if not torch.version.hip:
        return _get_cuda_library()
    return _get_hip_runtime_library()
# Helper: check CUDA errors
def _check_cuda(result: int) -> None:
    """
    Raise ``RuntimeError`` when a driver call returned a non-zero status.

    Args:
        result (int): Status code returned by a driver API call; ``0`` means success.
    """
    if result != 0:
        message_ptr = ctypes.c_char_p()
        # The runtime library is only loaded on the failure path, to look up
        # a textual description of the status code.
        driver = _get_gpu_runtime_library()
        driver.cuGetErrorString(result, ctypes.byref(message_ptr))
        raw_message = message_ptr.value
        if raw_message is None:
            error_message = "Unknown CUDA error"
        else:
            error_message = raw_message.decode()
        raise RuntimeError(f"CUDA error: {error_message}")
def _get_hiprtc_library() -> ctypes.CDLL:
    """
    Load the hipRTC library and expose its entry points under NVRTC names.

    Returns:
        ctypes.CDLL: The loaded ``hiprtc`` library with ``nvrtc*`` attributes
        aliased to the corresponding ``hiprtc*`` functions.
    """
    try:
        # pyrefly: ignore [import-error, missing-import]
        import rocm_sdk

        lib = ctypes.CDLL(str(rocm_sdk.find_libraries("hiprtc")[0]))
    except (ImportError, IndexError):
        if sys.platform == "win32":
            # The DLL suffix is the zero-padded major and minor version
            # (e.g. 6.4 -> "0604"). Parse the dotted components instead of
            # indexing single characters so multi-digit components work too.
            hip_major, hip_minor = (
                int(part)
                for part in torch.version.hip.split(".")[:2]  # type: ignore[union-attr]
            )
            lib = ctypes.CDLL(f"hiprtc{hip_major:02d}{hip_minor:02d}.dll")
        else:
            lib = ctypes.CDLL("libhiprtc.so")

    # Provide aliases for HIP RTC functions to match NVRTC API
    lib.nvrtcGetErrorString = lib.hiprtcGetErrorString  # type: ignore[attr-defined]
    lib.nvrtcCreateProgram = lib.hiprtcCreateProgram  # type: ignore[attr-defined]
    lib.nvrtcDestroyProgram = lib.hiprtcDestroyProgram  # type: ignore[attr-defined]
    lib.nvrtcCompileProgram = lib.hiprtcCompileProgram  # type: ignore[attr-defined]
    # hipRTC has no separate "CUBIN" notion; its generic code accessors are
    # mapped onto the NVRTC CUBIN names.
    lib.nvrtcGetCUBINSize = lib.hiprtcGetCodeSize  # type: ignore[attr-defined]
    lib.nvrtcGetCUBIN = lib.hiprtcGetCode  # type: ignore[attr-defined]
    lib.nvrtcGetProgramLogSize = lib.hiprtcGetProgramLogSize  # type: ignore[attr-defined]
    lib.nvrtcGetProgramLog = lib.hiprtcGetProgramLog  # type: ignore[attr-defined]
    lib.nvrtcAddNameExpression = lib.hiprtcAddNameExpression  # type: ignore[attr-defined]
    lib.nvrtcGetLoweredName = lib.hiprtcGetLoweredName  # type: ignore[attr-defined]
    return lib
def _get_nvrtc_library() -> ctypes.CDLL:
    """
    Load the NVRTC library matching the CUDA major version torch reports.

    Returns:
        ctypes.CDLL: The first candidate library that loads successfully.

    Raises:
        OSError: If none of the candidate library names can be loaded.
    """
    cuda_major = int(torch.version.cuda.split(".")[0])  # type: ignore[union-attr]

    if sys.platform == "win32":
        candidates = [f"nvrtc64_{cuda_major}0_0.dll"]
    else:
        # Versioned soname first, then the unversioned name as a fallback.
        candidates = [f"libnvrtc.so.{cuda_major}", "libnvrtc.so"]

    for candidate in candidates:
        try:
            handle = ctypes.CDLL(candidate)
        except OSError:
            # Try the next candidate name.
            continue
        else:
            return handle

    raise OSError("Could not find any NVRTC library")
def _get_gpu_rtc_library() -> ctypes.CDLL:
    """Return hipRTC on ROCm builds of torch, otherwise NVRTC."""
    # Since PyTorch already loads the GPU RTC library, we can use the system library
    # which should be compatible with PyTorch's version
    if not torch.version.hip:
        return _get_nvrtc_library()
    return _get_hiprtc_library()
def _get_gpu_rtc_compatible_flags() -> list[str]:
    """
    Collect the common HIPCC/NVCC flags that can be handed to runtime compilation.

    Returns:
        List of flags: ``COMMON_NVCC_FLAGS`` minus the ones NVRTC does not
        accept, followed by ``COMMON_HIPCC_FLAGS`` on ROCm builds.
    """
    from torch.utils.cpp_extension import COMMON_HIPCC_FLAGS, COMMON_NVCC_FLAGS

    # Flags that must not be forwarded to NVRTC.
    rejected = {"--expt-relaxed-constexpr"}

    flags: list[str] = []
    for candidate in COMMON_NVCC_FLAGS:
        if candidate in rejected:
            continue
        flags.append(candidate)

    if torch.version.hip:
        flags.extend(COMMON_HIPCC_FLAGS)
    return flags
- def _nvrtc_compile(
- kernel_source: str,
- kernel_name: str,
- compute_capability: str | None = None,
- cuda_include_dirs: list | None = None,
- nvcc_options: list | None = None,
- auto_pch: bool = False,
- ) -> tuple[bytes, str]:
- """
- Compiles a CUDA kernel using NVRTC and returns the PTX code.
- Args:
- kernel_source (str): The CUDA kernel source code as a string
- kernel_name (str): The name of the kernel function to compile
- compute_capability (str, None): The compute capability to target (e.g., "86").
- If None, will detect from current device.
- cuda_include_dirs (list, None): List of directories containing CUDA headers
- nvcc_options (list, None): Additional options to pass to NVRTC
- auto_pch (bool): Enable automatic precompiled headers (CUDA 12.8+)
- Returns:
- Tuple[bytes, str]: The compiled PTX code and mangled kernel name
- """
- # Ensure CUDA is initialized
- import torch.cuda
- # Load NVRTC library
- libnvrtc = _get_gpu_rtc_library()
- # NVRTC constants
- NVRTC_SUCCESS = 0
- # Helper: check NVRTC errors
- def check_nvrtc(result: int) -> None:
- if result != NVRTC_SUCCESS:
- err_str = ctypes.c_char_p()
- libnvrtc.nvrtcGetErrorString(result, ctypes.byref(err_str))
- error_message = (
- err_str.value.decode()
- if err_str.value is not None
- else "Unknown CUDA error"
- )
- raise RuntimeError(f"CUDA error: {error_message}")
- # Convert source to bytes
- source_bytes = kernel_source.encode("utf-8")
- # Get compute capability if not provided
- if compute_capability is None:
- props = torch.cuda.get_device_properties(torch.cuda.current_device())
- if torch.version.hip:
- compute_capability = f"{props.gcnArchName}"
- else:
- compute_capability = f"{props.major}{props.minor}"
- # Prepare compilation options
- options = []
- if torch.version.hip:
- options.append(f"--offload-arch={compute_capability}".encode())
- else:
- options.append(f"--gpu-architecture=sm_{compute_capability}".encode())
- # Auto-detect and add CUDA include paths
- from torch.utils.cpp_extension import include_paths
- cuda_include_paths = include_paths("cuda")
- for cuda_path in cuda_include_paths:
- options.append(f"-I{cuda_path}".encode())
- # Add custom include directories
- if cuda_include_dirs:
- for directory in cuda_include_dirs:
- options.append(f"-I{directory}".encode())
- # Enable automatic precompiled headers (CUDA 12.8+)
- if auto_pch:
- if str(torch.version.cuda) < "12.8":
- raise AssertionError(f"PCH requires CUDA 12.8+, got {torch.version.cuda}")
- if nvcc_options is None:
- nvcc_options = []
- nvcc_options.append("--pch")
- # Add custom NVCC options
- if nvcc_options:
- for option in nvcc_options:
- options.append(option.encode("utf-8"))
- nvrtc_compatible_flags = _get_gpu_rtc_compatible_flags()
- options.extend([flag.encode("utf-8") for flag in nvrtc_compatible_flags])
- # Convert options to C array
- num_options = len(options)
- options_array = (ctypes.c_char_p * num_options)(*options)
- # Create program
- prog = ctypes.c_void_p()
- check_nvrtc(
- libnvrtc.nvrtcCreateProgram(
- ctypes.byref(prog),
- source_bytes,
- f"{kernel_name}.cu".encode(),
- 0,
- None,
- None,
- )
- )
- # Add kernel name, which can be a template expression
- c_kernel_name = kernel_name.encode("utf-8")
- check_nvrtc(libnvrtc.nvrtcAddNameExpression(prog, c_kernel_name))
- # Compile program
- res = libnvrtc.nvrtcCompileProgram(prog, num_options, options_array)
- # Handle compilation errors
- if res != NVRTC_SUCCESS:
- # Get log
- log_size = ctypes.c_size_t()
- libnvrtc.nvrtcGetProgramLogSize(prog, ctypes.byref(log_size))
- log = ctypes.create_string_buffer(log_size.value)
- libnvrtc.nvrtcGetProgramLog(prog, log)
- raise RuntimeError(f"Kernel compilation failed:\n{log.value.decode()}")
- # Get binary
- binary_size = ctypes.c_size_t()
- check_nvrtc(libnvrtc.nvrtcGetCUBINSize(prog, ctypes.byref(binary_size)))
- binary = ctypes.create_string_buffer(binary_size.value)
- check_nvrtc(libnvrtc.nvrtcGetCUBIN(prog, binary))
- # Get mangled name
- c_mangled_name = ctypes.c_char_p()
- check_nvrtc(
- libnvrtc.nvrtcGetLoweredName(prog, c_kernel_name, ctypes.byref(c_mangled_name))
- )
- if c_mangled_name.value is not None:
- mangled_name = c_mangled_name.value.decode() # make a copy
- else:
- mangled_name = ""
- libnvrtc.nvrtcDestroyProgram(ctypes.byref(prog))
- # For some reason, ".value" causes the string to be truncated,
- # likely due to the presence of '\0' in the string. So we use .raw instead.
- return binary.raw, mangled_name
class _CudaModule:
    """
    Wraps a loaded GPU module handle and resolves kernels through attribute access.

    ``module.some_kernel`` looks the function up in the module on first use and
    caches the resulting ``_CudaKernel``.
    """

    def __init__(self, module: ctypes.c_void_p) -> None:
        self._module = module
        # Kernels already resolved from this module, keyed by name.
        self._kernels: dict[str, _CudaKernel] = {}

    def __getattr__(self, name: str) -> "_CudaKernel":
        """
        Return the kernel called ``name``, resolving and caching it on first access.

        Raises:
            AttributeError: If the module has no function with that name, or
                if the instance's own state has not been initialised.
        """
        # __getattr__ only runs after normal lookup has failed. If one of our
        # own attributes is missing (instance created without __init__, e.g.
        # while being copied or unpickled), reading self._kernels below would
        # re-enter this method until RecursionError; fail cleanly instead.
        if name in ("_module", "_kernels"):
            raise AttributeError(name)

        if name in self._kernels:
            return self._kernels[name]

        # Import the CUDA library inside the method
        # pyrefly: ignore [missing-module-attribute]
        from torch.cuda._utils import _get_gpu_runtime_library

        libcuda = _get_gpu_runtime_library()

        func = ctypes.c_void_p()
        try:
            _check_cuda(
                libcuda.cuModuleGetFunction(
                    ctypes.byref(func), self._module, name.encode("utf-8")
                )
            )
            kernel = _CudaKernel(func, self._module)
            self._kernels[name] = kernel
            return kernel

        except RuntimeError as err:
            # A failed lookup is reported as a missing attribute.
            raise AttributeError(f"No kernel named '{name}' in this module") from err
- class _CudaKernel:
- """
- Represents a compiled CUDA kernel that can be called with PyTorch tensors.
- """
- def __init__(self, func: ctypes.c_void_p, module: ctypes.c_void_p) -> None:
- self.func = func
- self.module = module
- self._max_shared_mem_bytes = 0
- def __call__(
- self,
- grid: tuple[int, int, int] = (1, 1, 1),
- block: tuple[int, int, int] = (1, 1, 1),
- args: list | None = None,
- shared_mem: int = 0,
- stream: Any | None = None,
- ) -> None:
- """
- Call the compiled CUDA kernel
- Args:
- grid (tuple): Grid dimensions (grid_x, grid_y, grid_z)
- block (tuple): Block dimensions (block_x, block_y, block_z)
- args (list): List of arguments to pass to the kernel.
- PyTorch tensor arguments will be automatically converted to pointers.
- shared_mem (int): Shared memory size in bytes
- stream (torch.cuda.Stream): CUDA stream to use. If None, uses current stream.
- """
- import torch
- libcuda = torch.cuda._utils._get_gpu_runtime_library()
- if not args:
- args = []
- # Process arguments and convert tensors to pointers
- processed_args: list[ctypes.c_void_p] = []
- c_args = []
- for arg in args:
- if isinstance(arg, torch.Tensor):
- if not arg.is_cuda and not (arg.is_cpu and arg.is_pinned()):
- raise ValueError(
- "All tensor arguments must be CUDA tensors or pinned CPU tensors"
- )
- # Get pointer to tensor data
- ptr = ctypes.c_void_p(arg.data_ptr())
- processed_args.append(ptr)
- c_args.append(ctypes.byref(ptr))
- elif isinstance(arg, int):
- # Convert integers to C int
- c_int = ctypes.c_int(arg)
- # Store the C int for reference keeping, not in processed_args
- c_args.append(ctypes.byref(c_int))
- elif isinstance(arg, float):
- # Python floats are doubles - use double by default
- c_double = ctypes.c_double(arg)
- # Store the C double for reference keeping, not in processed_args
- c_args.append(ctypes.byref(c_double))
- else:
- raise TypeError(f"Unsupported argument type: {type(arg)}")
- # Convert to array of void pointers
- c_args_array = (ctypes.c_void_p * len(c_args))()
- for i, arg in enumerate(c_args):
- c_args_array[i] = ctypes.cast(arg, ctypes.c_void_p)
- # Get the stream
- if stream is None:
- # Defer import to avoid circular imports
- import torch.cuda
- stream = torch.cuda.current_stream()
- # Check if kernel requires large shared memory but hasn't been configured
- if shared_mem >= 48 * 1024 and (
- self._max_shared_mem_bytes == 0 or shared_mem > self._max_shared_mem_bytes
- ):
- configured_msg = (
- "not configured"
- if self._max_shared_mem_bytes == 0
- else f"only {self._max_shared_mem_bytes} bytes configured"
- )
- raise RuntimeError(
- f"Kernel requires {shared_mem} bytes of shared memory (>= 48KB), "
- f"but {configured_msg}. "
- "Call kernel.set_shared_memory_config(shared_mem) after compilation "
- "and before launching the kernel."
- )
- _check_cuda(
- libcuda.cuLaunchKernel(
- self.func,
- grid[0],
- grid[1],
- grid[2],
- block[0],
- block[1],
- block[2],
- shared_mem,
- stream._as_parameter_,
- c_args_array,
- None,
- )
- )
- def set_shared_memory_config(self, shared_mem_bytes: int) -> None:
- if shared_mem_bytes < 48 * 1024:
- # No configuration needed for <= 48KB, just update the value
- self._max_shared_mem_bytes = shared_mem_bytes
- return
- libcuda = _get_gpu_runtime_library()
- # Get device properties to validate against limits
- device_props = torch.cuda.get_device_properties()
- # HIP doesn't have shared_memory_per_block_optin in device properties, so we hard-code it here
- if torch.version.hip:
- # navi, CDNA1-CDNA3 allows a max of 64KB shared memory
- # CDNA4 allows a max of 160KB shared memory
- max_shared_mem = (
- 65536 if device_props.gcnArchName != "gfx950" else 160 * 1024
- )
- else:
- max_shared_mem = getattr(
- device_props, "shared_memory_per_block_optin", 49152
- )
- if shared_mem_bytes > max_shared_mem:
- raise RuntimeError(
- f"Requested shared memory ({shared_mem_bytes} bytes) exceeds "
- f"device limit ({max_shared_mem} bytes). "
- "Consider reducing block size or shared memory usage."
- )
- # Set the function attribute once
- # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
- cudaFuncAttributeMaxDynamicSharedMemorySize = 8
- _check_cuda(
- libcuda.cuFuncSetAttribute(
- self.func,
- cudaFuncAttributeMaxDynamicSharedMemorySize,
- shared_mem_bytes,
- )
- )
- self._max_shared_mem_bytes = shared_mem_bytes
def _cuda_load_module(
    ptx: str | bytes, kernel_names: list[str] | None = None
) -> _CudaModule | dict[str, "_CudaKernel"]:
    """
    Load compiled kernel code as a GPU module and expose its kernels.

    Args:
        ptx (bytes or str): The code to load; a ``str`` is encoded as UTF-8 first
        kernel_names (list, optional): Names of the kernels to extract.
            If None or empty, a module object with __getattr__ is returned instead.

    Returns:
        object: A ``_CudaModule`` when no kernel names are given, otherwise a
            dict mapping each requested name to its ``_CudaKernel``.
    """
    # Make sure the torch.cuda submodule is loaded
    import torch.cuda

    # Load CUDA driver library
    libcuda = _get_gpu_runtime_library()

    # The loader takes bytes
    image = ptx.encode("utf-8") if isinstance(ptx, str) else ptx

    # Load the module while the current stream is active
    module = ctypes.c_void_p()
    stream = torch.cuda.current_stream()
    with stream:
        _check_cuda(libcuda.cuModuleLoadData(ctypes.byref(module), image))

    if not kernel_names:
        return _CudaModule(module)

    def resolve(kernel_name: str) -> "_CudaKernel":
        # Look up one function handle in the freshly loaded module
        handle = ctypes.c_void_p()
        status = libcuda.cuModuleGetFunction(
            ctypes.byref(handle), module, kernel_name.encode("utf-8")
        )
        _check_cuda(status)
        return _CudaKernel(handle, module)

    # Return specific kernels
    return {kernel_name: resolve(kernel_name) for kernel_name in kernel_names}
def _get_device_index(
    device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
    r"""Resolve :attr:`device` to a device index.

    :attr:`device` may be a torch.device object, a device string, a Python
    integer, a ``torch.cuda.device`` object, or ``None``.

    A Python integer is returned unchanged. A string is first converted to a
    torch.device. A torch.device must be a CUDA device, or, when
    :attr:`allow_cpu` is ``True``, a CUDA or CPU device; anything else raises
    ``ValueError``. A ``torch.cuda.device`` object yields its ``idx``.

    Every other case is delegated, together with :attr:`optional` and
    :attr:`allow_cpu`, to the generic helper imported from ``torch._utils``.
    """
    if isinstance(device, int):
        return device

    if isinstance(device, str):
        device = torch.device(device)

    if isinstance(device, torch.device):
        if not allow_cpu:
            if device.type != "cuda":
                raise ValueError(f"Expected a cuda device, but got: {device}")
        elif device.type not in ["cuda", "cpu"]:
            raise ValueError(f"Expected a cuda or cpu device, but got: {device}")

    if not torch.jit.is_scripting():
        # torch.cuda.device objects carry their index directly
        if isinstance(device, torch.cuda.device):
            return device.idx

    return _torch_get_device_index(device, optional, allow_cpu)
|