windows_utils.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. import ctypes
  2. import functools
  3. import os
  4. import re
  5. import subprocess
  6. import sys
  7. import sysconfig
  8. import warnings
  9. import winreg
  10. from collections.abc import Iterable
  11. from functools import partial
  12. from glob import glob
  13. from pathlib import Path
  14. from typing import Callable, Optional
  15. def find_in_program_files(rel_path: str) -> Optional[Path]:
  16. program_files = os.getenv("ProgramFiles(x86)", r"C:\Program Files (x86)")
  17. path = Path(program_files) / rel_path
  18. if path.exists():
  19. return path
  20. program_files = os.getenv("ProgramW6432", r"C:\Program Files")
  21. path = Path(program_files) / rel_path
  22. if path.exists():
  23. return path
  24. return None
  25. def parse_version(s: str, prefix: str = "") -> Optional[tuple[int, ...]]:
  26. s = s.removeprefix(prefix)
  27. try:
  28. return tuple(int(x) for x in s.split("."))
  29. except ValueError:
  30. return None
  31. def unparse_version(t: Iterable[int], prefix: str = "") -> str:
  32. return prefix + ".".join([str(x) for x in t])
  33. def max_version(
  34. versions: Iterable[str],
  35. prefix: str = "",
  36. check: Callable[[str], bool] = lambda x: True,
  37. ) -> Optional[str]:
  38. versions = [x for x in versions if check(x)]
  39. versions = [parse_version(x, prefix) for x in versions]
  40. versions = [x for x in versions if x is not None]
  41. if not versions:
  42. return None
  43. version = unparse_version(max(versions), prefix)
  44. return version
  45. def check_msvc(msvc_base_path: Path, version: str) -> bool:
  46. return all(x.exists() for x in [
  47. msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe",
  48. msvc_base_path / version / "include" / "vcruntime.h",
  49. msvc_base_path / version / "lib" / "x64" / "vcruntime.lib",
  50. ])
  51. def find_msvc_env() -> tuple[Optional[Path], Optional[str]]:
  52. msvc_base_path = os.getenv("VCINSTALLDIR")
  53. if msvc_base_path is None:
  54. return None, None
  55. msvc_base_path = Path(msvc_base_path) / "Tools" / "MSVC"
  56. version = os.getenv("VCToolsVersion")
  57. if not check_msvc(msvc_base_path, version):
  58. warnings.warn(f"Environment variables VCINSTALLDIR = {os.getenv('VCINSTALLDIR')}, "
  59. f"VCToolsVersion = {os.getenv('VCToolsVersion')} are set, "
  60. "but this MSVC installation is incomplete.")
  61. return None, None
  62. return msvc_base_path, version
  63. def find_msvc_vswhere() -> tuple[Optional[Path], Optional[str]]:
  64. vswhere_path = find_in_program_files(r"Microsoft Visual Studio\Installer\vswhere.exe")
  65. if vswhere_path is None:
  66. return None, None
  67. command = [
  68. str(vswhere_path),
  69. "-prerelease",
  70. "-products",
  71. "*",
  72. "-requires",
  73. "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
  74. "-requires",
  75. "Microsoft.VisualStudio.Component.Windows10SDK",
  76. "-latest",
  77. "-property",
  78. "installationPath",
  79. ]
  80. try:
  81. output = subprocess.check_output(command, text=True).strip()
  82. except subprocess.CalledProcessError:
  83. return None, None
  84. msvc_base_path = Path(output) / "VC" / "Tools" / "MSVC"
  85. if not msvc_base_path.exists():
  86. return None, None
  87. version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
  88. if version is None:
  89. return None, None
  90. return msvc_base_path, version
  91. def find_msvc_envpath() -> tuple[Optional[Path], Optional[str]]:
  92. paths = os.getenv("PATH", "").split(os.pathsep)
  93. for path in paths:
  94. path = path.replace("/", "\\")
  95. match = re.compile(r".*\\VC\\Tools\\MSVC\\").match(path)
  96. if not match:
  97. continue
  98. msvc_base_path = Path(match.group(0))
  99. if not msvc_base_path.exists():
  100. continue
  101. version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
  102. if version is None:
  103. continue
  104. return msvc_base_path, version
  105. return None, None
  106. def find_msvc_hardcoded() -> tuple[Optional[Path], Optional[str]]:
  107. vs_path = find_in_program_files("Microsoft Visual Studio")
  108. if vs_path is None:
  109. return None, None
  110. paths = glob(str(vs_path / "*" / "*" / "VC" / "Tools" / "MSVC"))
  111. # First try the highest version
  112. paths = sorted(paths)[::-1]
  113. for msvc_base_path in paths:
  114. msvc_base_path = Path(msvc_base_path)
  115. version = max_version(os.listdir(msvc_base_path), check=partial(check_msvc, msvc_base_path))
  116. if version is None:
  117. continue
  118. return msvc_base_path, version
  119. return None, None
  120. def find_msvc(env_only: bool) -> tuple[Optional[str], list[str], list[str]]:
  121. if env_only:
  122. fs = [find_msvc_env]
  123. else:
  124. fs = [
  125. find_msvc_env,
  126. find_msvc_vswhere,
  127. find_msvc_envpath,
  128. find_msvc_hardcoded,
  129. ]
  130. for f in fs:
  131. msvc_base_path, version = f()
  132. if msvc_base_path:
  133. return (
  134. str(msvc_base_path / version / "bin" / "Hostx64" / "x64" / "cl.exe"),
  135. [str(msvc_base_path / version / "include")],
  136. [str(msvc_base_path / version / "lib" / "x64")],
  137. )
  138. if not env_only:
  139. warnings.warn("Failed to find MSVC.")
  140. return None, [], []
  141. def check_winsdk(winsdk_base_path: Path, version: str) -> bool:
  142. return all(x.exists() for x in [
  143. winsdk_base_path / "Include" / version / "ucrt" / "stdlib.h",
  144. winsdk_base_path / "Lib" / version / "ucrt" / "x64" / "ucrt.lib",
  145. ])
  146. def find_winsdk_env() -> tuple[Optional[Path], Optional[str]]:
  147. winsdk_base_path = os.getenv("WindowsSdkDir")
  148. if winsdk_base_path is None:
  149. return None, None
  150. winsdk_base_path = Path(winsdk_base_path)
  151. version = os.getenv("WindowsSDKVersion")
  152. if version is None:
  153. version = os.getenv("WindowsSDKVer")
  154. if version is None:
  155. warnings.warn(f"Environment variable WindowsSdkDir = {winsdk_base_path}, "
  156. "but WindowsSDKVersion (or WindowsSDKVer) is not set.")
  157. return None, None
  158. version = version.rstrip("\\")
  159. if not check_winsdk(winsdk_base_path, version):
  160. warnings.warn(f"Environment variables WindowsSdkDir = {winsdk_base_path}, "
  161. f"WindowsSDKVersion (or WindowsSDKVer) = {version} are set, "
  162. "but this Windows SDK installation is incomplete.")
  163. return None, None
  164. return winsdk_base_path, version
  165. def find_winsdk_registry() -> tuple[Optional[Path], Optional[str]]:
  166. try:
  167. reg = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE)
  168. key = winreg.OpenKeyEx(reg, r"SOFTWARE\WOW6432Node\Microsoft\Microsoft SDKs\Windows\v10.0")
  169. folder = winreg.QueryValueEx(key, "InstallationFolder")[0]
  170. winreg.CloseKey(key)
  171. except OSError:
  172. return None, None
  173. winsdk_base_path = Path(folder)
  174. if not (winsdk_base_path / "Include").exists():
  175. return None, None
  176. version = max_version(
  177. os.listdir(winsdk_base_path / "Include"),
  178. check=partial(check_winsdk, winsdk_base_path),
  179. )
  180. if version is None:
  181. return None, None
  182. return winsdk_base_path, version
  183. def find_winsdk_hardcoded() -> tuple[Optional[Path], Optional[str]]:
  184. winsdk_base_path = find_in_program_files(r"Windows Kits\10")
  185. if winsdk_base_path is None:
  186. return None, None
  187. if not (winsdk_base_path / "Include").exists():
  188. return None, None
  189. version = max_version(
  190. os.listdir(winsdk_base_path / "Include"),
  191. check=partial(check_winsdk, winsdk_base_path),
  192. )
  193. if version is None:
  194. return None, None
  195. return winsdk_base_path, version
  196. def find_winsdk(env_only: bool) -> tuple[list[str], list[str]]:
  197. if env_only:
  198. fs = [find_winsdk_env]
  199. else:
  200. fs = [
  201. find_winsdk_env,
  202. find_winsdk_registry,
  203. find_winsdk_hardcoded,
  204. ]
  205. for f in fs:
  206. winsdk_base_path, version = f()
  207. if winsdk_base_path:
  208. return (
  209. [
  210. str(winsdk_base_path / "Include" / version / "shared"),
  211. str(winsdk_base_path / "Include" / version / "ucrt"),
  212. str(winsdk_base_path / "Include" / version / "um"),
  213. ],
  214. [
  215. str(winsdk_base_path / "Lib" / version / "ucrt" / "x64"),
  216. str(winsdk_base_path / "Lib" / version / "um" / "x64"),
  217. ],
  218. )
  219. if not env_only:
  220. warnings.warn("Failed to find Windows SDK.")
  221. return [], []
  222. @functools.lru_cache
  223. def find_msvc_winsdk(env_only: bool = False, ) -> tuple[Optional[str], list[str], list[str]]:
  224. msvc_bin_path, msvc_inc_dirs, msvc_lib_dirs = find_msvc(env_only)
  225. winsdk_inc_dirs, winsdk_lib_dirs = find_winsdk(env_only)
  226. return (
  227. msvc_bin_path,
  228. msvc_inc_dirs + winsdk_inc_dirs,
  229. msvc_lib_dirs + winsdk_lib_dirs,
  230. )
  231. @functools.lru_cache
  232. def find_python() -> list[str]:
  233. version = sysconfig.get_python_version().replace(".", "")
  234. if sysconfig.get_config_var("Py_GIL_DISABLED"):
  235. version += "t"
  236. for python_base_path in [
  237. sys.exec_prefix,
  238. sys.base_exec_prefix,
  239. os.path.dirname(sys.executable),
  240. ]:
  241. python_lib_dir = Path(python_base_path) / "libs"
  242. if (python_lib_dir / f"python{version}.lib").exists():
  243. return [str(python_lib_dir)]
  244. warnings.warn("Failed to find Python libs.")
  245. return []
  246. def check_and_find_cuda(base_path: Path) -> tuple[Optional[str], list[str], list[str]]:
  247. # pip
  248. if all(x.exists() for x in [
  249. base_path / "cuda_nvcc" / "bin" / "ptxas.exe",
  250. base_path / "cuda_runtime" / "include" / "cuda.h",
  251. base_path / "cuda_runtime" / "lib" / "x64" / "cuda.lib",
  252. ]):
  253. return (
  254. str(base_path / "cuda_nvcc" / "bin"),
  255. [str(base_path / "cuda_runtime" / "include")],
  256. [str(base_path / "cuda_runtime" / "lib" / "x64")],
  257. )
  258. # conda
  259. if all(x.exists() for x in [
  260. base_path / "bin" / "ptxas.exe",
  261. base_path / "include" / "cuda.h",
  262. base_path / "lib" / "cuda.lib",
  263. ]):
  264. return (
  265. str(base_path / "bin"),
  266. [str(base_path / "include")],
  267. [str(base_path / "lib")],
  268. )
  269. # bundled or system-wide
  270. if all(x.exists() for x in [
  271. base_path / "bin" / "ptxas.exe",
  272. base_path / "include" / "cuda.h",
  273. base_path / "lib" / "x64" / "cuda.lib",
  274. ]):
  275. return (
  276. str(base_path / "bin"),
  277. [str(base_path / "include")],
  278. [str(base_path / "lib" / "x64")],
  279. )
  280. return None, [], []
  281. def find_cuda_env() -> tuple[Optional[str], list[str], list[str]]:
  282. for cuda_base_path in ["CUDA_PATH", "CUDA_HOME"]:
  283. cuda_base_path = os.getenv(cuda_base_path)
  284. if cuda_base_path is None:
  285. continue
  286. cuda_base_path = Path(cuda_base_path)
  287. cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
  288. if cuda_bin_path:
  289. return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
  290. return None, [], []
  291. def find_cuda_bundled() -> tuple[Optional[str], list[str], list[str]]:
  292. cuda_base_path = (Path(sysconfig.get_paths()["platlib"]) / "triton" / "backends" / "nvidia")
  293. return check_and_find_cuda(cuda_base_path)
  294. def find_cuda_pip() -> tuple[Optional[str], list[str], list[str]]:
  295. nvidia_base_path = Path(sysconfig.get_paths()["platlib"]) / "nvidia"
  296. return check_and_find_cuda(nvidia_base_path)
  297. def find_cuda_conda() -> tuple[Optional[str], list[str], list[str]]:
  298. cuda_base_path = Path(sys.exec_prefix) / "Library"
  299. return check_and_find_cuda(cuda_base_path)
  300. def find_cuda_hardcoded() -> tuple[Optional[str], list[str], list[str]]:
  301. parent = find_in_program_files(r"NVIDIA GPU Computing Toolkit\CUDA")
  302. if parent is None:
  303. return None, [], []
  304. paths = glob(str(parent / "v12*"))
  305. # First try the highest version
  306. paths = sorted(paths)[::-1]
  307. for cuda_base_path in paths:
  308. cuda_base_path = Path(cuda_base_path)
  309. cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = check_and_find_cuda(cuda_base_path)
  310. if cuda_bin_path:
  311. return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
  312. return None, [], []
  313. @functools.lru_cache
  314. def find_cuda() -> tuple[Optional[str], list[str], list[str]]:
  315. for f in [
  316. find_cuda_env,
  317. find_cuda_bundled,
  318. find_cuda_pip,
  319. find_cuda_conda,
  320. find_cuda_hardcoded,
  321. ]:
  322. cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs = f()
  323. if cuda_bin_path:
  324. return cuda_bin_path, cuda_inc_dirs, cuda_lib_dirs
  325. warnings.warn("Failed to find CUDA.")
  326. return None, [], []
  327. @functools.lru_cache
  328. def find_hip() -> tuple[Optional[str], list[str], list[str]]:
  329. """Find HIP SDK paths (bin, include dirs, lib dirs) from ROCm SDK wheels."""
  330. try:
  331. import rocm_sdk
  332. paths = rocm_sdk.find_libraries("amdhip64")
  333. if paths:
  334. bin_dir = str(paths[0].parent)
  335. root = str(paths[0].parent.parent)
  336. inc_dir = os.path.join(root, "include")
  337. lib_dir = os.path.join(root, "lib")
  338. inc_dirs = [inc_dir] if os.path.isdir(inc_dir) else []
  339. lib_dirs = [bin_dir]
  340. if os.path.isdir(lib_dir):
  341. lib_dirs.append(lib_dir)
  342. return bin_dir, inc_dirs, lib_dirs
  343. except (ImportError, ModuleNotFoundError):
  344. pass
  345. warnings.warn("Failed to find ROCm/HIP.")
  346. return None, [], []
  347. def normalize_path(path: str) -> str:
  348. r"""Normalize to absolute path with UNC prefix \\?\ so it does not suffer from 260-char length limit."""
  349. if os.name != "nt":
  350. return path
  351. path = os.path.abspath(path).replace("/", "\\")
  352. if path.startswith("\\\\?\\"):
  353. return path
  354. if path.startswith("\\\\.\\"):
  355. # Local device path such as \\.\C:\path\to\file.c
  356. # Change the prefix from \\.\ to \\?\
  357. return "\\\\?\\" + path[4:]
  358. if path.startswith("\\\\"):
  359. # Standard UNC path such as \\localhost\C$\path\to\file.c or \\ServerName\ShareName\path\to\file.c
  360. # Change the prefix from \\ to \\?\UNC\
  361. return "\\\\?\\UNC\\" + path[2:]
  362. # Now the path must be a non-UNC path
  363. return "\\\\?\\" + path
  364. kernel32 = None
  365. if os.name == "nt":
  366. from ctypes import wintypes
  367. kernel32 = ctypes.WinDLL("kernel32", use_last_error=True)
  368. kernel32.GetShortPathNameW.argtypes = [wintypes.LPCWSTR, wintypes.LPWSTR, wintypes.DWORD]
  369. kernel32.GetShortPathNameW.restype = wintypes.DWORD
  370. def get_8dot3_short_path(path: str) -> str:
  371. if os.name != "nt":
  372. return path
  373. path = normalize_path(path)
  374. req_size = kernel32.GetShortPathNameW(path, None, 0)
  375. if req_size > 0:
  376. buf = ctypes.create_unicode_buffer(req_size)
  377. if kernel32.GetShortPathNameW(path, buf, req_size) > 0:
  378. path = buf.value
  379. if path.startswith("\\\\?\\UNC\\"):
  380. return "\\\\" + path[8:]
  381. if path.startswith("\\\\?\\"):
  382. return path[4:]
  383. return path