| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190 |
- from __future__ import annotations
- import functools
- import hashlib
- import importlib.util
- import logging
- import os
- import shutil
- import subprocess
- import sysconfig
- import tempfile
- import re
- from types import ModuleType
- from triton.windows_utils import get_8dot3_short_path, normalize_path
- from .cache import get_cache_manager
- from .. import knobs
- if os.name == "nt":
- from triton.windows_utils import find_msvc_winsdk, find_python
- @functools.lru_cache
- def get_cc():
- cc = os.environ.get("CC")
- if cc is None:
- # clang-cl from TheRock ROCm wheels (handles HIP C headers that mix C/C++ constructs)
- cc = os.path.join(sysconfig.get_path("platlib"), "_rocm_sdk_core", "lib", "llvm", "bin", "clang-cl.exe")
- if not os.path.exists(cc):
- cc = None
- if cc is None:
- # Find and check MSVC and Windows SDK from environment variables set by Launch-VsDevShell.ps1 or VsDevCmd.bat
- cc, _, _ = find_msvc_winsdk(env_only=True)
- if cc is None:
- # Bundled TinyCC
- cc = os.path.join(sysconfig.get_paths()["platlib"], "triton", "runtime", "tcc", "tcc.exe")
- if not os.path.exists(cc):
- cc = None
- if cc is None:
- cc = shutil.which("cl")
- if cc is None:
- cc = shutil.which("gcc")
- if cc is None:
- cc = shutil.which("clang")
- if cc is None:
- raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
- return cc
- def is_tcc(cc):
- cc = os.path.basename(cc).lower()
- return cc == "tcc" or cc == "tcc.exe"
- def is_msvc(cc):
- cc = os.path.basename(cc).lower()
- return cc == "cl" or cc == "cl.exe"
- def is_clang_cl(cc):
- cc = os.path.basename(cc).lower()
- return cc == "clang-cl" or cc == "clang-cl.exe"
- def is_clang(cc):
- cc = os.path.basename(cc).lower()
- return cc == "clang" or cc == "clang.exe"
- def _cc_cmd(cc: str, src: str, out: str, include_dirs: list[str], library_dirs: list[str], libraries: list[str],
- ccflags: list[str]) -> list[str]:
- if is_msvc(cc) or is_clang_cl(cc):
- out_base = os.path.splitext(out)[0]
- cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/wd4819", "/std:c11"]
- cc_cmd += [f"/I{dir}" for dir in include_dirs if dir is not None]
- cc_cmd += [f"/Fo{out_base + '.obj'}"]
- cc_cmd += ["/link"]
- cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs]
- cc_cmd += [f"{lib}.lib" for lib in libraries]
- cc_cmd += [f"/OUT:{out}"]
- cc_cmd += [f"/IMPLIB:{out_base + '.lib'}"]
- cc_cmd += [f"/PDB:{out_base + '.pdb'}"]
- else:
- # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
- cc_cmd = [cc, src, "-O3", "-shared", "-Wno-psabi", "-o", out]
- if not (os.name == "nt" and is_clang(cc)):
- # Clang does not support -fPIC on Windows
- cc_cmd += ["-fPIC"]
- if is_tcc(cc):
- cc_cmd += ["-D_Py_USE_GCC_BUILTIN_ATOMICS"]
- cc_cmd += [_library_flag(lib) for lib in libraries]
- cc_cmd += [f"-L{dir}" for dir in library_dirs]
- cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
- cc_cmd += ccflags
- return cc_cmd
- def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
- ccflags: list[str]) -> str:
- if impl := knobs.build.impl:
- return impl(name, src, srcdir, library_dirs, include_dirs, libraries)
- cc = get_cc()
- if is_msvc(cc):
- # MSVC does not support UNC path with \\?\ prefix. We convert it to 8.3 short path.
- src = get_8dot3_short_path(src)
- srcdir = get_8dot3_short_path(srcdir)
- suffix = sysconfig.get_config_var('EXT_SUFFIX')
- so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
- scheme = sysconfig.get_default_scheme()
- # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
- # path changes to include 'local'. This change is required to use triton with system-wide python.
- if scheme == 'posix_local':
- scheme = 'posix_prefix'
- py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
- custom_backend_dirs = knobs.build.backend_dirs
- # Don't append in place
- include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
- if os.name == "nt":
- library_dirs = library_dirs + find_python()
- version = sysconfig.get_python_version().replace(".", "")
- if sysconfig.get_config_var("Py_GIL_DISABLED"):
- version += "t"
- libraries = libraries + [f"python{version}"]
- if is_msvc(cc) or is_clang_cl(cc):
- _, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
- include_dirs = include_dirs + msvc_winsdk_inc_dirs
- library_dirs = library_dirs + msvc_winsdk_lib_dirs
- cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries, ccflags)
- try:
- subprocess.check_call(cc_cmd)
- except Exception as e:
- print("Failed to compile. cc_cmd:", cc_cmd)
- raise e
- return so
- def _library_flag(lib: str) -> str:
- # Match .so files with optional version numbers (e.g., .so, .so.1, .so.513.50.1)
- if re.search(r'\.so(\.\d+)*$', lib) or lib.endswith(".a"):
- return f"-l:{lib}"
- return f"-l{lib}"
- @functools.lru_cache
- def platform_key() -> str:
- from platform import machine, system, architecture
- return ",".join([machine(), system(), *architecture()])
- def _load_module_from_path(name: str, path: str) -> ModuleType:
- # Loading module with relative path may cause error. `normalize_path` normalizes to absolute path.
- path = normalize_path(path)
- spec = importlib.util.spec_from_file_location(name, path)
- if not spec or not spec.loader:
- raise RuntimeError(f"Failed to load newly compiled {name} from {path}")
- mod = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(mod)
- return mod
- def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
- include_dirs: list[str] | None = None, libraries: list[str] | None = None,
- ccflags: list[str] | None = None) -> ModuleType:
- key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
- cache = get_cache_manager(key)
- suffix = sysconfig.get_config_var("EXT_SUFFIX")
- cache_path = cache.get_file(f"{name}{suffix}")
- if cache_path is not None:
- try:
- return _load_module_from_path(name, cache_path)
- except (RuntimeError, ImportError):
- log = logging.getLogger(__name__)
- log.warning(f"Triton cache error: compiled module {name}.so could not be loaded")
- with tempfile.TemporaryDirectory() as tmpdir:
- tmpdir = normalize_path(tmpdir)
- src_path = os.path.join(tmpdir, name + ".c")
- with open(src_path, "w") as f:
- f.write(src)
- so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
- with open(so, "rb") as f:
- cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
- return _load_module_from_path(name, cache_path)
|