build.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. from __future__ import annotations
  2. import functools
  3. import hashlib
  4. import importlib.util
  5. import logging
  6. import os
  7. import shutil
  8. import subprocess
  9. import sysconfig
  10. import tempfile
  11. import re
  12. from types import ModuleType
  13. from triton.windows_utils import get_8dot3_short_path, normalize_path
  14. from .cache import get_cache_manager
  15. from .. import knobs
  16. if os.name == "nt":
  17. from triton.windows_utils import find_msvc_winsdk, find_python
  18. @functools.lru_cache
  19. def get_cc():
  20. cc = os.environ.get("CC")
  21. if cc is None:
  22. # clang-cl from TheRock ROCm wheels (handles HIP C headers that mix C/C++ constructs)
  23. cc = os.path.join(sysconfig.get_path("platlib"), "_rocm_sdk_core", "lib", "llvm", "bin", "clang-cl.exe")
  24. if not os.path.exists(cc):
  25. cc = None
  26. if cc is None:
  27. # Find and check MSVC and Windows SDK from environment variables set by Launch-VsDevShell.ps1 or VsDevCmd.bat
  28. cc, _, _ = find_msvc_winsdk(env_only=True)
  29. if cc is None:
  30. # Bundled TinyCC
  31. cc = os.path.join(sysconfig.get_paths()["platlib"], "triton", "runtime", "tcc", "tcc.exe")
  32. if not os.path.exists(cc):
  33. cc = None
  34. if cc is None:
  35. cc = shutil.which("cl")
  36. if cc is None:
  37. cc = shutil.which("gcc")
  38. if cc is None:
  39. cc = shutil.which("clang")
  40. if cc is None:
  41. raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
  42. return cc
  43. def is_tcc(cc):
  44. cc = os.path.basename(cc).lower()
  45. return cc == "tcc" or cc == "tcc.exe"
  46. def is_msvc(cc):
  47. cc = os.path.basename(cc).lower()
  48. return cc == "cl" or cc == "cl.exe"
  49. def is_clang_cl(cc):
  50. cc = os.path.basename(cc).lower()
  51. return cc == "clang-cl" or cc == "clang-cl.exe"
  52. def is_clang(cc):
  53. cc = os.path.basename(cc).lower()
  54. return cc == "clang" or cc == "clang.exe"
  55. def _cc_cmd(cc: str, src: str, out: str, include_dirs: list[str], library_dirs: list[str], libraries: list[str],
  56. ccflags: list[str]) -> list[str]:
  57. if is_msvc(cc) or is_clang_cl(cc):
  58. out_base = os.path.splitext(out)[0]
  59. cc_cmd = [cc, src, "/nologo", "/O2", "/LD", "/wd4819", "/std:c11"]
  60. cc_cmd += [f"/I{dir}" for dir in include_dirs if dir is not None]
  61. cc_cmd += [f"/Fo{out_base + '.obj'}"]
  62. cc_cmd += ["/link"]
  63. cc_cmd += [f"/LIBPATH:{dir}" for dir in library_dirs]
  64. cc_cmd += [f"{lib}.lib" for lib in libraries]
  65. cc_cmd += [f"/OUT:{out}"]
  66. cc_cmd += [f"/IMPLIB:{out_base + '.lib'}"]
  67. cc_cmd += [f"/PDB:{out_base + '.pdb'}"]
  68. else:
  69. # for -Wno-psabi, see https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111047
  70. cc_cmd = [cc, src, "-O3", "-shared", "-Wno-psabi", "-o", out]
  71. if not (os.name == "nt" and is_clang(cc)):
  72. # Clang does not support -fPIC on Windows
  73. cc_cmd += ["-fPIC"]
  74. if is_tcc(cc):
  75. cc_cmd += ["-D_Py_USE_GCC_BUILTIN_ATOMICS"]
  76. cc_cmd += [_library_flag(lib) for lib in libraries]
  77. cc_cmd += [f"-L{dir}" for dir in library_dirs]
  78. cc_cmd += [f"-I{dir}" for dir in include_dirs if dir is not None]
  79. cc_cmd += ccflags
  80. return cc_cmd
  81. def _build(name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str], libraries: list[str],
  82. ccflags: list[str]) -> str:
  83. if impl := knobs.build.impl:
  84. return impl(name, src, srcdir, library_dirs, include_dirs, libraries)
  85. cc = get_cc()
  86. if is_msvc(cc):
  87. # MSVC does not support UNC path with \\?\ prefix. We convert it to 8.3 short path.
  88. src = get_8dot3_short_path(src)
  89. srcdir = get_8dot3_short_path(srcdir)
  90. suffix = sysconfig.get_config_var('EXT_SUFFIX')
  91. so = os.path.join(srcdir, '{name}{suffix}'.format(name=name, suffix=suffix))
  92. scheme = sysconfig.get_default_scheme()
  93. # 'posix_local' is a custom scheme on Debian. However, starting Python 3.10, the default install
  94. # path changes to include 'local'. This change is required to use triton with system-wide python.
  95. if scheme == 'posix_local':
  96. scheme = 'posix_prefix'
  97. py_include_dir = sysconfig.get_paths(scheme=scheme)["include"]
  98. custom_backend_dirs = knobs.build.backend_dirs
  99. # Don't append in place
  100. include_dirs = include_dirs + [srcdir, py_include_dir, *custom_backend_dirs]
  101. if os.name == "nt":
  102. library_dirs = library_dirs + find_python()
  103. version = sysconfig.get_python_version().replace(".", "")
  104. if sysconfig.get_config_var("Py_GIL_DISABLED"):
  105. version += "t"
  106. libraries = libraries + [f"python{version}"]
  107. if is_msvc(cc) or is_clang_cl(cc):
  108. _, msvc_winsdk_inc_dirs, msvc_winsdk_lib_dirs = find_msvc_winsdk()
  109. include_dirs = include_dirs + msvc_winsdk_inc_dirs
  110. library_dirs = library_dirs + msvc_winsdk_lib_dirs
  111. cc_cmd = _cc_cmd(cc, src, so, include_dirs, library_dirs, libraries, ccflags)
  112. try:
  113. subprocess.check_call(cc_cmd)
  114. except Exception as e:
  115. print("Failed to compile. cc_cmd:", cc_cmd)
  116. raise e
  117. return so
  118. def _library_flag(lib: str) -> str:
  119. # Match .so files with optional version numbers (e.g., .so, .so.1, .so.513.50.1)
  120. if re.search(r'\.so(\.\d+)*$', lib) or lib.endswith(".a"):
  121. return f"-l:{lib}"
  122. return f"-l{lib}"
  123. @functools.lru_cache
  124. def platform_key() -> str:
  125. from platform import machine, system, architecture
  126. return ",".join([machine(), system(), *architecture()])
  127. def _load_module_from_path(name: str, path: str) -> ModuleType:
  128. # Loading module with relative path may cause error. `normalize_path` normalizes to absolute path.
  129. path = normalize_path(path)
  130. spec = importlib.util.spec_from_file_location(name, path)
  131. if not spec or not spec.loader:
  132. raise RuntimeError(f"Failed to load newly compiled {name} from {path}")
  133. mod = importlib.util.module_from_spec(spec)
  134. spec.loader.exec_module(mod)
  135. return mod
  136. def compile_module_from_src(src: str, name: str, library_dirs: list[str] | None = None,
  137. include_dirs: list[str] | None = None, libraries: list[str] | None = None,
  138. ccflags: list[str] | None = None) -> ModuleType:
  139. key = hashlib.sha256((src + platform_key()).encode("utf-8")).hexdigest()
  140. cache = get_cache_manager(key)
  141. suffix = sysconfig.get_config_var("EXT_SUFFIX")
  142. cache_path = cache.get_file(f"{name}{suffix}")
  143. if cache_path is not None:
  144. try:
  145. return _load_module_from_path(name, cache_path)
  146. except (RuntimeError, ImportError):
  147. log = logging.getLogger(__name__)
  148. log.warning(f"Triton cache error: compiled module {name}.so could not be loaded")
  149. with tempfile.TemporaryDirectory() as tmpdir:
  150. tmpdir = normalize_path(tmpdir)
  151. src_path = os.path.join(tmpdir, name + ".c")
  152. with open(src_path, "w") as f:
  153. f.write(src)
  154. so = _build(name, src_path, tmpdir, library_dirs or [], include_dirs or [], libraries or [], ccflags or [])
  155. with open(so, "rb") as f:
  156. cache_path = cache.put(f.read(), f"{name}{suffix}", binary=True)
  157. return _load_module_from_path(name, cache_path)