compiler.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562
  1. from triton.backends.compiler import BaseBackend, GPUTarget, Language
  2. from triton._C.libtriton import ir, passes, llvm, nvidia
  3. from triton import knobs
  4. from triton.runtime.errors import PTXASError
  5. from dataclasses import dataclass
  6. import functools
  7. from typing import Any, Dict, Tuple, Optional
  8. from types import ModuleType
  9. import hashlib
  10. import re
  11. import tempfile
  12. import signal
  13. import os
  14. import subprocess
  15. from pathlib import Path
  16. def min_dot_size(target: GPUTarget):
  17. def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]: # [m, n, k]
  18. lhs_bitwidth = lhs_type.scalar.primitive_bitwidth
  19. rhs_bitwidth = rhs_type.scalar.primitive_bitwidth
  20. assert lhs_bitwidth == rhs_bitwidth, "lhs and rhs bitwidth must be the same"
  21. # For small M/N the input we can still use tensorcores with padding.
  22. if lhs_bitwidth == 8:
  23. return (1, 1, 32)
  24. else:
  25. return (1, 1, 16)
  26. return check_dot_compatibility
  27. def get_ptxas(arch: int) -> knobs.NvidiaTool:
  28. # return knobs.nvidia.ptxas_blackwell if arch >= 100 else knobs.nvidia.ptxas
  29. return knobs.nvidia.ptxas
  30. @functools.lru_cache()
  31. def get_ptxas_version(arch: int = 80):
  32. mock_ver = knobs.nvidia.mock_ptx_version
  33. if mock_ver is not None:
  34. return mock_ver # This is not really a version of ptxas, but it is good enough for testing
  35. version = subprocess.check_output([get_ptxas(arch).path, "--version"]).decode("utf-8")
  36. return version
  37. @functools.lru_cache()
  38. def ptx_get_version(cuda_version) -> int:
  39. '''
  40. Get the highest PTX version supported by the current CUDA driver.
  41. '''
  42. assert isinstance(cuda_version, str)
  43. major, minor = map(int, cuda_version.split('.'))
  44. if major == 12:
  45. if minor < 6:
  46. return 80 + minor
  47. else:
  48. return 80 + minor - 1
  49. if major == 11:
  50. return 70 + minor
  51. if major == 10:
  52. return 63 + minor
  53. if major >= 13:
  54. base_ptx = 90
  55. return base_ptx + (major - 13) * 10 + minor
  56. raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
  57. def get_ptx_version_from_options(options, arch: int):
  58. ptx_version = options.ptx_version
  59. if ptx_version is None:
  60. cuda_version = get_ptxas(arch).version
  61. ptx_version = ptx_get_version(cuda_version)
  62. return ptx_version
  63. @functools.lru_cache()
  64. def get_features(options, arch: int):
  65. ptx_version = get_ptx_version_from_options(options, arch)
  66. # PTX 8.6 is the max version supported by llvm c1188642.
  67. #
  68. # To check if a newer PTX version is supported, increase this value
  69. # and run a test. If it's not supported, LLVM will print a warning
  70. # like "+ptx8.4 is not a recognized feature for this target".
  71. llvm_ptx_version = min(86, ptx_version)
  72. features = f'+ptx{llvm_ptx_version}'
  73. return features
  74. @functools.lru_cache(None)
  75. def file_hash(path):
  76. with open(path, "rb") as f:
  77. return hashlib.sha256(f.read()).hexdigest()
  78. def sm_arch_from_capability(capability: int):
  79. # TODO: Handle non-"a" sms
  80. suffix = "a" if capability >= 90 else ""
  81. return f"sm_{capability}{suffix}"
  82. # The file may be accessed in parallel
  83. def try_remove(path):
  84. if os.path.exists(path):
  85. try:
  86. os.remove(path)
  87. except OSError:
  88. import traceback
  89. traceback.print_exc()
  90. @dataclass(frozen=True)
  91. class CUDAOptions:
  92. num_warps: int = 4
  93. num_ctas: int = 1
  94. num_stages: int = 3
  95. warp_size: int = 32
  96. # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
  97. # maximum number of 32-bit registers used by one thread.
  98. maxnreg: Optional[int] = None
  99. ptx_version: int = None
  100. ptx_options: Optional[str] = knobs.nvidia.ptxas_options
  101. ir_override: Optional[str] = None # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
  102. enable_fp_fusion: bool = True
  103. enable_reflect_ftz: bool = True # ftz in libdevice
  104. launch_cooperative_grid: bool = False
  105. launch_pdl: bool = False
  106. supported_fp8_dtypes: Tuple[str] = ("fp8e4nv", "fp8e5", "fp8e4b15")
  107. deprecated_fp8_dot_operand_dtypes: Tuple[str] = ()
  108. default_dot_input_precision: str = "tf32"
  109. allowed_dot_input_precisions: Tuple[str] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16x6')
  110. max_num_imprecise_acc_default: bool = None
  111. extern_libs: dict = None
  112. debug: bool = False
  113. backend_name: str = 'cuda'
  114. sanitize_overflow: bool = True
  115. arch: str = None
  116. instrumentation_mode: str = ""
  117. def __post_init__(self):
  118. default_libdir = Path(__file__).parent / 'lib'
  119. extern_libs = {} if self.extern_libs is None else dict(self.extern_libs)
  120. if not extern_libs.get('libdevice', None):
  121. extern_libs['libdevice'] = knobs.nvidia.libdevice_path or str(default_libdir / 'libdevice.10.bc')
  122. object.__setattr__(self, 'extern_libs', tuple(extern_libs.items()))
  123. assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
  124. "num_warps must be a power of 2"
  125. def hash(self):
  126. hash_dict = dict(self.__dict__)
  127. hash_dict["extern_libs"] = tuple((k, file_hash(v)) for k, v in sorted(hash_dict["extern_libs"]))
  128. key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items())])
  129. return hashlib.sha256(key.encode("utf-8")).hexdigest()
  130. class CUDABackend(BaseBackend):
  131. instrumentation = None
  132. @staticmethod
  133. def supports_target(target: GPUTarget):
  134. return target.backend == 'cuda'
  135. def _parse_arch(self, arch):
  136. pattern = r"^sm(\d+)$"
  137. match = re.fullmatch(pattern, arch)
  138. if not match:
  139. raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
  140. return int(match.group(1))
  141. def get_target_name(self, options) -> str:
  142. capability = self._parse_arch(options.arch)
  143. return f"cuda:{capability}"
  144. def __init__(self, target: GPUTarget) -> None:
  145. super().__init__(target)
  146. self.binary_ext = "cubin"
  147. def parse_options(self, opts) -> Any:
  148. # Enable debug mode for ConSan, so device-side assertions are not optimized out
  149. if "instrumentation_mode" in opts and opts["instrumentation_mode"] == "consan":
  150. opts["debug"] = True
  151. args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
  152. args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
  153. capability = int(self._parse_arch(args["arch"]))
  154. if args.get("num_ctas", 1) > 1 and capability < 90:
  155. raise ValueError((f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
  156. f"Current target is sm_{capability}. This configuration will fail. "
  157. f"Please set num_ctas=1 or target an SM90+ GPU."))
  158. if "supported_fp8_dtypes" not in args:
  159. supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
  160. args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))
  161. if "deprecated_fp8_dot_operand_dtypes" not in args:
  162. if capability >= 90:
  163. args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )
  164. if "enable_fp_fusion" not in args:
  165. args["enable_fp_fusion"] = knobs.language.default_fp_fusion
  166. args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0
  167. return CUDAOptions(**args)
  168. def pack_metadata(self, metadata):
  169. return (
  170. metadata.num_warps,
  171. metadata.num_ctas,
  172. metadata.shared,
  173. )
  174. def get_codegen_implementation(self, options):
  175. import triton.language.extra.cuda as cuda
  176. capability = int(self._parse_arch(options.arch))
  177. codegen_fns = {
  178. "convert_custom_types":
  179. cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70, "min_dot_size":
  180. min_dot_size(self.target)
  181. }
  182. return codegen_fns
  183. def get_module_map(self) -> Dict[str, ModuleType]:
  184. from triton.language.extra.cuda import libdevice
  185. return {"triton.language.extra.libdevice": libdevice}
  186. def load_dialects(self, ctx):
  187. nvidia.load_dialects(ctx)
  188. if CUDABackend.instrumentation:
  189. CUDABackend.instrumentation.load_dialects(ctx)
  190. @staticmethod
  191. def make_ttir(mod, metadata, opt, capability):
  192. pm = ir.pass_manager(mod.context)
  193. pm.enable_debug()
  194. passes.common.add_inliner(pm)
  195. passes.ttir.add_rewrite_tensor_pointer(pm)
  196. if capability // 10 < 9:
  197. passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
  198. passes.common.add_canonicalizer(pm)
  199. passes.ttir.add_combine(pm)
  200. passes.ttir.add_reorder_broadcast(pm)
  201. passes.common.add_cse(pm)
  202. passes.common.add_symbol_dce(pm)
  203. passes.ttir.add_loop_unroll(pm)
  204. pm.run(mod, 'make_ttir')
  205. return mod
  206. @staticmethod
  207. def make_ttgir(mod, metadata, opt, capability):
  208. # Set maxnreg on all kernels, if it was provided.
  209. if opt.maxnreg is not None:
  210. mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))
  211. pm = ir.pass_manager(mod.context)
  212. dump_enabled = pm.enable_debug()
  213. emuTF32 = (capability // 10 >= 8)
  214. passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
  215. # optimize TTGIR
  216. passes.ttgpuir.add_coalesce(pm)
  217. passes.ttgpuir.add_f32_dot_tc(pm, emuTF32)
  218. # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
  219. nvidia.passes.ttnvgpuir.add_plan_cta(pm)
  220. passes.ttgpuir.add_remove_layout_conversions(pm)
  221. passes.ttgpuir.add_optimize_thread_locality(pm)
  222. passes.ttgpuir.add_accelerate_matmul(pm)
  223. passes.ttgpuir.add_remove_layout_conversions(pm)
  224. passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
  225. nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
  226. passes.ttir.add_loop_aware_cse(pm)
  227. if capability // 10 in [8, 9]:
  228. passes.ttgpuir.add_fuse_nested_loops(pm)
  229. passes.common.add_canonicalizer(pm)
  230. passes.ttir.add_triton_licm(pm)
  231. passes.common.add_canonicalizer(pm)
  232. passes.ttgpuir.add_combine_tensor_select_and_if(pm)
  233. nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
  234. passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
  235. passes.ttgpuir.add_schedule_loops(pm)
  236. passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
  237. elif capability // 10 >= 10:
  238. passes.ttgpuir.add_fuse_nested_loops(pm)
  239. passes.common.add_canonicalizer(pm)
  240. passes.ttir.add_triton_licm(pm)
  241. passes.ttgpuir.add_optimize_accumulator_init(pm)
  242. passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
  243. nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
  244. passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
  245. passes.ttgpuir.add_schedule_loops(pm)
  246. passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
  247. passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
  248. passes.ttgpuir.add_optimize_partition_warps(pm)
  249. passes.ttgpuir.add_combine_tensor_select_and_if(pm)
  250. # hoist again and allow hoisting out of if statements
  251. passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
  252. nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
  253. else:
  254. passes.ttir.add_triton_licm(pm)
  255. passes.common.add_canonicalizer(pm)
  256. passes.ttir.add_loop_aware_cse(pm)
  257. passes.ttgpuir.add_prefetch(pm)
  258. passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
  259. passes.ttgpuir.add_coalesce_async_copy(pm)
  260. nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
  261. if capability // 10 >= 9:
  262. nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
  263. passes.ttgpuir.add_remove_layout_conversions(pm)
  264. nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
  265. passes.ttgpuir.add_reduce_data_duplication(pm)
  266. passes.ttgpuir.add_reorder_instructions(pm)
  267. passes.ttir.add_loop_aware_cse(pm)
  268. passes.common.add_symbol_dce(pm)
  269. nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
  270. nvidia.passes.ttnvgpuir.add_lower_mma(pm)
  271. passes.common.add_sccp(pm)
  272. passes.common.add_cse(pm)
  273. passes.common.add_canonicalizer(pm)
  274. pm.run(mod, 'make_ttgir')
  275. metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
  276. return mod
  277. def gluon_to_ttgir(self, src, metadata, options, capability):
  278. mod = src
  279. pm = ir.pass_manager(mod.context)
  280. pm.enable_debug()
  281. passes.gluon.add_inliner(pm)
  282. passes.gluon.add_infer_coalesced_encodings(pm)
  283. passes.gluon.add_resolve_auto_encodings(pm)
  284. nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
  285. passes.gluon.add_canonicalizer(pm)
  286. passes.common.add_sccp(pm)
  287. passes.ttir.add_loop_aware_cse(pm)
  288. passes.gluon.add_canonicalizer(pm)
  289. passes.ttgpuir.add_combine_tensor_select_and_if(pm)
  290. pm.run(mod, 'gluon_to_ttgir')
  291. metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
  292. return mod
  293. def make_llir(self, src, metadata, options, capability):
  294. ptx_version = get_ptx_version_from_options(options, self.target.arch)
  295. mod = src
  296. # TritonGPU -> LLVM-IR (MLIR)
  297. pm = ir.pass_manager(mod.context)
  298. pm.enable_debug()
  299. passes.ttgpuir.add_combine_tensor_select_and_if(pm)
  300. passes.ttgpuir.add_allocate_warp_groups(pm)
  301. passes.convert.add_scf_to_cf(pm)
  302. passes.gluon.add_inliner(pm)
  303. nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
  304. nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
  305. nvidia.passes.ttnvgpuir.add_check_matmul_two_cta(pm)
  306. if knobs.compilation.instrumentation_mode == "consan":
  307. # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
  308. passes.ttgpuir.add_concurrency_sanitizer(pm)
  309. passes.ttgpuir.add_allocate_global_scratch_memory(pm)
  310. nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
  311. # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
  312. if CUDABackend.instrumentation:
  313. CUDABackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
  314. nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
  315. passes.common.add_canonicalizer(pm)
  316. passes.common.add_cse(pm)
  317. nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
  318. nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
  319. passes.common.add_canonicalizer(pm)
  320. passes.common.add_cse(pm)
  321. passes.common.add_symbol_dce(pm)
  322. passes.convert.add_nvvm_to_llvm(pm)
  323. if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
  324. passes.llvmir.add_di_scope(pm)
  325. if CUDABackend.instrumentation:
  326. CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)
  327. pm.run(mod, 'make_llir')
  328. if knobs.compilation.dump_ir_extract_di_local_variables:
  329. # comments below on why separate it
  330. if not knobs.compilation.disable_line_info:
  331. pm = ir.pass_manager(mod.context)
  332. pm.enable_debug()
  333. passes.llvmir.add_di_scope(pm)
  334. pm.run(mod, 'make_llir.disable_line_info')
  335. # insert dbg intrinsic with several DI Attribute including source
  336. # var name and type info note: unknown reason for now, but this
  337. # pass and add_di_scope has to be run separately, otherwise if we
  338. # put them into previous pipline, it trigger a segmentfault without
  339. # any error message; could be due to a bug in mlir or pybind11
  340. pm = ir.pass_manager(mod.context)
  341. pm.enable_debug()
  342. passes.llvmir.add_di_local_variable(pm)
  343. pm.run(mod, 'make_llir.dump_ir_extract_di_local_variables')
  344. # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
  345. llvm.init_targets()
  346. context = llvm.context()
  347. if knobs.compilation.enable_asan:
  348. raise RuntimeError(
  349. "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
  350. llvm_mod = llvm.to_module(mod, context)
  351. proc = sm_arch_from_capability(capability)
  352. features = get_features(options, self.target.arch)
  353. triple = 'nvptx64-nvidia-cuda'
  354. nvidia.set_short_ptr()
  355. llvm.attach_datalayout(llvm_mod, triple, proc, features)
  356. if options.enable_reflect_ftz:
  357. nvidia.set_nvvm_reflect_ftz(llvm_mod)
  358. if options.extern_libs and nvidia.has_extern_deps(llvm_mod):
  359. paths = [path for (name, path) in options.extern_libs]
  360. llvm.link_extern_libs(llvm_mod, paths)
  361. llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
  362. # Get some metadata
  363. # warp-specialization mutates num_warps
  364. total_num_warps = src.get_int_attr("ttg.total-num-warps")
  365. if total_num_warps is not None:
  366. metadata["num_warps"] = total_num_warps
  367. metadata["shared"] = src.get_int_attr("ttg.shared")
  368. metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
  369. metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
  370. metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
  371. metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
  372. metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
  373. ret = str(llvm_mod)
  374. del llvm_mod
  375. del context
  376. return ret
  377. def make_ptx(self, src, metadata, opt, capability):
  378. ptx_version = get_ptx_version_from_options(opt, self.target.arch)
  379. triple = 'nvptx64-nvidia-cuda'
  380. proc = sm_arch_from_capability(capability)
  381. features = get_features(opt, self.target.arch)
  382. flags = ["nvptx-mad-wide-opt"]
  383. ret = llvm.translate_to_asm(src, triple, proc, features, flags, opt.enable_fp_fusion, False)
  384. # Find kernel names (there should only be one)
  385. names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
  386. assert len(names) == 1
  387. metadata["name"] = names[0]
  388. # post-process
  389. ptx_version = f'{ptx_version//10}.{ptx_version%10}'
  390. ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
  391. ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
  392. if not knobs.compilation.dump_ir_extract_di_local_variables:
  393. # Remove the debug flag that prevents ptxas from optimizing the code
  394. # Note: if this flag is removed, the source var name and type info will be lost when ptx was compiled into cubin
  395. # and we may not be able to see them in cuda-gdb
  396. ret = re.sub(r",\s*debug|debug,\s*", "", ret)
  397. if knobs.nvidia.dump_nvptx:
  398. print("// -----// NVPTX Dump //----- //")
  399. print(ret)
  400. return ret
  401. def make_cubin(self, src, metadata, opt, capability):
  402. ptxas = get_ptxas(self.target.arch).path
  403. with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
  404. tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
  405. fsrc.write(src)
  406. fsrc.flush()
  407. fbin = fsrc.name + '.o'
  408. debug_info = []
  409. if knobs.compilation.disable_line_info:
  410. # This option is ignored if used without -lineinfo
  411. debug_info += ["-lineinfo", "-suppress-debug-info"]
  412. elif knobs.nvidia.disable_ptxas_opt:
  413. # Synthesize complete debug info
  414. debug_info += ["-g"]
  415. else:
  416. # Only emit line info
  417. debug_info += ["-lineinfo"]
  418. fmad = [] if opt.enable_fp_fusion else ["--fmad=false"]
  419. arch = sm_arch_from_capability(capability)
  420. # Disable ptxas optimizations if requested
  421. disable_opt = ['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else []
  422. # Accept more ptxas options if provided
  423. ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else []
  424. ptxas_cmd = [
  425. ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name,
  426. '-o', fbin
  427. ]
  428. try:
  429. # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
  430. # On Windows, both stdout and stderr need to be redirected to flog
  431. subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog,
  432. stderr=flog)
  433. if knobs.nvidia.dump_ptxas_log:
  434. with open(flog.name) as log_file:
  435. print(log_file.read())
  436. except subprocess.CalledProcessError as e:
  437. with open(flog.name) as log_file:
  438. log = log_file.read()
  439. if e.returncode == 255:
  440. error = 'Internal Triton PTX codegen error'
  441. elif e.returncode == 128 + signal.SIGSEGV:
  442. error = '`ptxas` raised SIGSEGV'
  443. else:
  444. error = f'`ptxas` failed with error code {e.returncode}'
  445. error = (f"{error}\n"
  446. f"`ptxas` stderr:\n{log}\n"
  447. f'Repro command: {" ".join(ptxas_cmd)}\n')
  448. print(f"""
  449. ================================================================
  450. {error}
  451. {src}
  452. ================================================================
  453. please share the reproducer above with Triton project.
  454. """)
  455. raise PTXASError(error)
  456. with open(fbin, 'rb') as f:
  457. cubin = f.read()
  458. try_remove(fbin)
  459. # It's better to remove the temp files outside the context managers
  460. try_remove(fsrc.name)
  461. try_remove(flog.name)
  462. return cubin
  463. def add_stages(self, stages, options, language):
  464. capability = self._parse_arch(options.arch)
  465. if language == Language.TRITON:
  466. stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
  467. stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
  468. elif language == Language.GLUON:
  469. stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability)
  470. stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
  471. stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
  472. stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
  473. if knobs.runtime.add_stages_inspection_hook is not None:
  474. knobs.runtime.add_stages_inspection_hook(self, stages, options, language, capability)
  475. @functools.lru_cache()
  476. def hash(self):
  477. version = get_ptxas_version(self.target.arch)
  478. return f'{version}-{self.target.arch}'