| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562 |
- from triton.backends.compiler import BaseBackend, GPUTarget, Language
- from triton._C.libtriton import ir, passes, llvm, nvidia
- from triton import knobs
- from triton.runtime.errors import PTXASError
- from dataclasses import dataclass
- import functools
- from typing import Any, Dict, Tuple, Optional
- from types import ModuleType
- import hashlib
- import re
- import tempfile
- import signal
- import os
- import subprocess
- from pathlib import Path
def min_dot_size(target: GPUTarget):
    """Build the callback that reports the minimum dot shape for two operand types.

    `target` is not used by the current implementation.

    Returns:
        A function ``(lhs_type, rhs_type) -> (m, n, k)``. It asserts that both
        operands have the same scalar bit width.
    """

    def check_dot_compatibility(lhs_type, rhs_type) -> Tuple[int, int, int]:  # [m, n, k]
        lhs_bits = lhs_type.scalar.primitive_bitwidth
        rhs_bits = rhs_type.scalar.primitive_bitwidth
        assert lhs_bits == rhs_bits, "lhs and rhs bitwidth must be the same"
        # For small M/N inputs we can still use tensor cores with padding,
        # so only the K dimension carries a real minimum.
        min_k = 32 if lhs_bits == 8 else 16
        return (1, 1, min_k)

    return check_dot_compatibility
def get_ptxas(arch: int) -> knobs.NvidiaTool:
    """Return the `ptxas` tool configured in `knobs.nvidia`.

    NOTE: `arch` is currently ignored. A disabled alternative in this function
    selected `knobs.nvidia.ptxas_blackwell` when `arch >= 100`.
    """
    ptxas_tool = knobs.nvidia.ptxas
    return ptxas_tool
@functools.lru_cache()
def get_ptxas_version(arch: int = 80):
    """Return the text printed by `ptxas --version`, memoized per `arch`.

    When `knobs.nvidia.mock_ptx_version` is set, that value is returned instead
    and `ptxas` is not invoked.
    """
    mocked = knobs.nvidia.mock_ptx_version
    if mocked is None:
        raw_output = subprocess.check_output([get_ptxas(arch).path, "--version"])
        return raw_output.decode("utf-8")
    # This is not really a version of ptxas, but it is good enough for testing.
    return mocked
@functools.lru_cache()
def ptx_get_version(cuda_version) -> int:
    """Get the highest PTX version supported by the given CUDA version.

    Args:
        cuda_version: CUDA version as a ``"major.minor"`` string.

    Returns:
        The PTX version encoded as ``major * 10 + minor`` (e.g. 84 for PTX 8.4).

    Raises:
        RuntimeError: if the CUDA major version is below 10.
    """
    assert isinstance(cuda_version, str)
    major, minor = (int(component) for component in cuda_version.split('.'))

    if major >= 13:
        # CUDA 13.0 maps to PTX 9.0; each later major adds a full PTX major.
        return 90 + (major - 13) * 10 + minor
    if major == 12:
        # From CUDA 12.6 onwards the PTX minor lags the CUDA minor by one.
        return 80 + minor if minor < 6 else 79 + minor
    if major == 11:
        return 70 + minor
    if major == 10:
        return 63 + minor
    raise RuntimeError("Triton only support CUDA 10.0 or higher, but got CUDA version: " + cuda_version)
def get_ptx_version_from_options(options, arch: int):
    """Return the PTX version to target.

    An explicit `options.ptx_version` wins. Otherwise the version is derived
    from the CUDA version reported by the configured `ptxas` tool.
    """
    explicit_version = options.ptx_version
    if explicit_version is not None:
        return explicit_version
    toolkit_version = get_ptxas(arch).version
    return ptx_get_version(toolkit_version)
@functools.lru_cache()
def get_features(options, arch: int):
    """Return the LLVM target feature string (``+ptxNN``) for these options.

    The PTX version is capped at the newest one the bundled LLVM accepts.
    Results are memoized on ``(options, arch)``.
    """
    # PTX 8.6 is the max version supported by llvm c1188642.
    #
    # To check if a newer PTX version is supported, increase this value
    # and run a test. If it's not supported, LLVM will print a warning
    # like "+ptx8.4 is not a recognized feature for this target".
    llvm_max_ptx = 86
    requested_ptx = get_ptx_version_from_options(options, arch)
    capped_ptx = min(llvm_max_ptx, requested_ptx)
    return f'+ptx{capped_ptx}'
@functools.lru_cache(maxsize=None)
def file_hash(path):
    """Return the hex SHA-256 digest of the file at `path` (memoized per path)."""
    with open(path, "rb") as stream:
        digest = hashlib.sha256(stream.read())
    return digest.hexdigest()
def sm_arch_from_capability(capability: int):
    """Return the ``sm_NN`` architecture name for an integer compute capability.

    Capabilities of 90 and above get an ``a`` suffix (e.g. ``sm_90a``).
    """
    # TODO: Handle non-"a" sms
    if capability >= 90:
        return f"sm_{capability}a"
    return f"sm_{capability}"
# The file may be accessed in parallel
def try_remove(path):
    """Best-effort removal of `path`; never raises for filesystem errors.

    A missing file is treated as success. The removal is attempted directly
    rather than guarded by an existence check, so a parallel process deleting
    the file first does not produce a spurious traceback. Any other `OSError`
    is reported with a traceback and swallowed.
    """
    try:
        os.remove(path)
    except (FileNotFoundError, NotADirectoryError):
        # Already gone (or never existed): nothing to do.
        return
    except OSError:
        import traceback
        traceback.print_exc()
@dataclass(frozen=True)
class CUDAOptions:
    """Immutable compilation options for the CUDA backend.

    `__post_init__` normalizes `extern_libs` to a tuple of ``(name, path)``
    pairs (always including a ``libdevice`` entry) and asserts that
    `num_warps` is a positive power of two.
    """
    num_warps: int = 4
    num_ctas: int = 1
    num_stages: int = 3
    warp_size: int = 32
    # maxnreg corresponds to the ptx parameter .maxnreg, which controls the
    # maximum number of 32-bit registers used by one thread.
    maxnreg: Optional[int] = None
    # Encoded as major * 10 + minor; None means "derive from the ptxas toolkit version".
    ptx_version: Optional[int] = None
    # The default is read from knobs once, when this class body is executed.
    ptx_options: Optional[str] = knobs.nvidia.ptxas_options
    ir_override: Optional[str] = None  # filename of a user-defined IR (*.{ttir|ttgir|llir|ptx})
    enable_fp_fusion: bool = True
    enable_reflect_ftz: bool = True  # ftz in libdevice
    launch_cooperative_grid: bool = False
    launch_pdl: bool = False
    supported_fp8_dtypes: Tuple[str, ...] = ("fp8e4nv", "fp8e5", "fp8e4b15")
    deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = ()
    default_dot_input_precision: str = "tf32"
    allowed_dot_input_precisions: Tuple[str, ...] = ("tf32", "tf32x3", "ieee", 'bf16x3', 'bf16x6')
    # An integer count (the backend sets 2**30 or 0), not a flag.
    max_num_imprecise_acc_default: Optional[int] = None
    # Accepts None or anything `dict()` accepts; stored as a tuple of (name, path) pairs.
    extern_libs: Optional[dict] = None
    debug: bool = False
    backend_name: str = 'cuda'
    sanitize_overflow: bool = True
    arch: Optional[str] = None
    instrumentation_mode: str = ""

    def __post_init__(self):
        """Fill in the libdevice path, freeze `extern_libs`, and validate `num_warps`."""
        bundled_lib_dir = Path(__file__).parent / 'lib'
        libs = {} if self.extern_libs is None else dict(self.extern_libs)
        if not libs.get('libdevice', None):
            libs['libdevice'] = knobs.nvidia.libdevice_path or str(bundled_lib_dir / 'libdevice.10.bc')
        # The dataclass is frozen, so bypass its __setattr__ to store the
        # normalized (and hashable) tuple form.
        object.__setattr__(self, 'extern_libs', tuple(libs.items()))
        assert self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0, \
            "num_warps must be a power of 2"

    def hash(self):
        """Return a SHA-256 hex digest over all option values.

        External libraries contribute the hash of their file contents rather
        than their paths.
        """
        hashed_fields = dict(self.__dict__)
        hashed_fields["extern_libs"] = tuple(
            (lib_name, file_hash(lib_path)) for lib_name, lib_path in sorted(hashed_fields["extern_libs"]))
        key = "_".join([f"{name}-{val}" for name, val in sorted(hashed_fields.items())])
        return hashlib.sha256(key.encode("utf-8")).hexdigest()
class CUDABackend(BaseBackend):
    """Compiler backend that lowers Triton / Gluon IR to a CUDA cubin.

    The stages registered by `add_stages` are ttir -> ttgir -> llir -> ptx -> cubin.
    """
    # Optional instrumentation hook object; when set it can load extra dialects
    # and patch the LLVM lowering pass pipelines.
    instrumentation = None

    @staticmethod
    def supports_target(target: GPUTarget):
        """Return True when `target.backend` is ``'cuda'``."""
        return target.backend == 'cuda'

    def _parse_arch(self, arch):
        """Parse an ``sm<digits>`` string and return the digits as an int.

        Raises:
            ValueError: if `arch` does not have that form.
        """
        pattern = r"^sm(\d+)$"
        match = re.fullmatch(pattern, arch)
        if not match:
            raise ValueError(f"TRITON_OVERRIDE_ARCH must have the form {pattern}")
        return int(match.group(1))

    def get_target_name(self, options) -> str:
        """Return ``cuda:<capability>`` for the architecture in `options.arch`."""
        capability = self._parse_arch(options.arch)
        return f"cuda:{capability}"

    def __init__(self, target: GPUTarget) -> None:
        super().__init__(target)
        self.binary_ext = "cubin"

    def parse_options(self, opts) -> Any:
        """Build a `CUDAOptions` from a mapping of user-supplied options.

        Only keys that are `CUDAOptions` fields and whose value is not None are
        forwarded. Architecture-dependent defaults are filled in afterwards.

        Note: `opts` is modified in place (``opts["debug"] = True``) when
        ``instrumentation_mode`` is ``"consan"``.

        Raises:
            ValueError: if ``num_ctas > 1`` is requested for a capability below 90,
                or the architecture string is malformed.
        """
        # Enable debug mode for ConSan, so device-side assertions are not optimized out
        if "instrumentation_mode" in opts and opts["instrumentation_mode"] == "consan":
            opts["debug"] = True

        args = {'arch': knobs.runtime.override_arch or f"sm{self.target.arch}"}
        args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
        capability = int(self._parse_arch(args["arch"]))

        if args.get("num_ctas", 1) > 1 and capability < 90:
            raise ValueError((f"num_ctas > 1 requires NVIDIA SM90+ (Hopper). "
                              f"Current target is sm_{capability}. This configuration will fail. "
                              f"Please set num_ctas=1 or target an SM90+ GPU."))

        if "supported_fp8_dtypes" not in args:
            supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes)
            args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes))

        if "deprecated_fp8_dot_operand_dtypes" not in args:
            if capability >= 90:
                args["deprecated_fp8_dot_operand_dtypes"] = ("fp8e4b15", )

        if "enable_fp_fusion" not in args:
            args["enable_fp_fusion"] = knobs.language.default_fp_fusion

        # Always overwritten, regardless of what the caller supplied.
        args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0

        return CUDAOptions(**args)

    def pack_metadata(self, metadata):
        """Return ``(num_warps, num_ctas, shared)`` taken from `metadata`."""
        return (
            metadata.num_warps,
            metadata.num_ctas,
            metadata.shared,
        )

    def get_codegen_implementation(self, options):
        """Return the frontend codegen callbacks for the architecture in `options.arch`."""
        import triton.language.extra.cuda as cuda
        capability = int(self._parse_arch(options.arch))
        codegen_fns = {
            "convert_custom_types":
            cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70,
            "min_dot_size":
            min_dot_size(self.target),
        }
        return codegen_fns

    def get_module_map(self) -> Dict[str, ModuleType]:
        """Map the generic libdevice module name to the CUDA implementation."""
        from triton.language.extra.cuda import libdevice
        return {"triton.language.extra.libdevice": libdevice}

    def load_dialects(self, ctx):
        """Load the NVIDIA dialects (and instrumentation dialects, if any) into `ctx`."""
        nvidia.load_dialects(ctx)
        if CUDABackend.instrumentation:
            CUDABackend.instrumentation.load_dialects(ctx)

    @staticmethod
    def make_ttir(mod, metadata, opt, capability):
        """Run the TTIR pass pipeline on `mod` in place and return it."""
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()
        passes.common.add_inliner(pm)
        passes.ttir.add_rewrite_tensor_pointer(pm)
        # NOTE(review): SOURCE had lost its indentation; only this single pass
        # is assumed to be conditional on capability < 90 - confirm.
        if capability // 10 < 9:
            passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_combine(pm)
        passes.ttir.add_reorder_broadcast(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.ttir.add_loop_unroll(pm)
        pm.run(mod, 'make_ttir')
        return mod

    @staticmethod
    def make_ttgir(mod, metadata, opt, capability):
        """Convert TTIR to TTGIR and run the TTGIR optimization pipeline.

        Also stores ``metadata["tensordesc_meta"]``. The loop-related passes
        differ by architecture family (``capability // 10``): 8 and 9,
        10 and above, or older.
        """
        # Set maxnreg on all kernels, if it was provided.
        if opt.maxnreg is not None:
            mod.set_attr("ttg.maxnreg", ir.builder(mod.context).get_int32_attr(opt.maxnreg))

        pm = ir.pass_manager(mod.context)
        dump_enabled = pm.enable_debug()
        emuTF32 = (capability // 10 >= 8)
        passes.ttir.add_convert_to_ttgpuir(pm, f"cuda:{capability}", opt.num_warps, 32, opt.num_ctas)
        # optimize TTGIR
        passes.ttgpuir.add_coalesce(pm)
        passes.ttgpuir.add_f32_dot_tc(pm, emuTF32)
        # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass
        nvidia.passes.ttnvgpuir.add_plan_cta(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_thread_locality(pm)
        passes.ttgpuir.add_accelerate_matmul(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        nvidia.passes.ttnvgpuir.add_optimize_descriptor_encoding(pm)
        passes.ttir.add_loop_aware_cse(pm)
        if capability // 10 in [8, 9]:
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            nvidia.passes.hopper.add_hopper_warpspec(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
        elif capability // 10 >= 10:
            passes.ttgpuir.add_fuse_nested_loops(pm)
            passes.common.add_canonicalizer(pm)
            passes.ttir.add_triton_licm(pm)
            passes.ttgpuir.add_optimize_accumulator_init(pm)
            passes.ttgpuir.add_hoist_tmem_alloc(pm, False)
            nvidia.passes.ttnvgpuir.add_promote_lhs_to_tmem(pm)
            passes.ttgpuir.add_assign_latencies(pm, opt.num_stages)
            passes.ttgpuir.add_schedule_loops(pm)
            passes.ttgpuir.add_warp_specialize(pm, opt.num_stages)
            passes.ttgpuir.add_pipeline(pm, opt.num_stages, dump_enabled)
            passes.ttgpuir.add_optimize_partition_warps(pm)
            passes.ttgpuir.add_combine_tensor_select_and_if(pm)
            # hoist again and allow hoisting out of if statements
            passes.ttgpuir.add_hoist_tmem_alloc(pm, True)
            nvidia.passes.ttnvgpuir.add_remove_tmem_tokens(pm)
        else:
            # NOTE(review): indentation was lost in SOURCE; the else branch is
            # assumed to contain only this pass - confirm.
            passes.ttir.add_triton_licm(pm)
        passes.common.add_canonicalizer(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.ttgpuir.add_prefetch(pm)
        passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
        passes.ttgpuir.add_coalesce_async_copy(pm)
        nvidia.passes.ttnvgpuir.add_optimize_tmem_layouts(pm)
        # NOTE(review): only TMA lowering is assumed to be conditional here - confirm.
        if capability // 10 >= 9:
            nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
        passes.ttgpuir.add_remove_layout_conversions(pm)
        nvidia.passes.ttnvgpuir.add_interleave_tmem(pm)
        passes.ttgpuir.add_reduce_data_duplication(pm)
        passes.ttgpuir.add_reorder_instructions(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.common.add_symbol_dce(pm)
        nvidia.passes.ttnvgpuir.add_fence_insertion(pm, capability)
        nvidia.passes.ttnvgpuir.add_lower_mma(pm)
        passes.common.add_sccp(pm)
        passes.common.add_cse(pm)
        passes.common.add_canonicalizer(pm)

        pm.run(mod, 'make_ttgir')
        metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
        return mod

    def gluon_to_ttgir(self, src, metadata, options, capability):
        """Run the Gluon front-end pipeline on `src` in place and return it.

        Also stores ``metadata["tensordesc_meta"]``.
        """
        mod = src
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()

        passes.gluon.add_inliner(pm)
        passes.gluon.add_infer_coalesced_encodings(pm)
        passes.gluon.add_resolve_auto_encodings(pm)
        nvidia.passes.ttnvgpuir.add_tma_lowering(pm)
        passes.gluon.add_canonicalizer(pm)
        passes.common.add_sccp(pm)
        passes.ttir.add_loop_aware_cse(pm)
        passes.gluon.add_canonicalizer(pm)
        passes.ttgpuir.add_combine_tensor_select_and_if(pm)

        pm.run(mod, 'gluon_to_ttgir')
        metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
        return mod

    def make_llir(self, src, metadata, options, capability):
        """Lower TTGIR to LLVM IR and return it as text.

        Populates `metadata` with ``shared``, ``tmem_size``, the global and
        profile scratch sizes/alignments and, when warp specialization changed
        it, ``num_warps``.

        Raises:
            RuntimeError: if `knobs.compilation.enable_asan` is set.
        """
        ptx_version = get_ptx_version_from_options(options, self.target.arch)

        mod = src
        # TritonGPU -> LLVM-IR (MLIR)
        pm = ir.pass_manager(mod.context)
        pm.enable_debug()

        passes.ttgpuir.add_combine_tensor_select_and_if(pm)
        passes.ttgpuir.add_allocate_warp_groups(pm)
        passes.convert.add_scf_to_cf(pm)
        passes.gluon.add_inliner(pm)
        nvidia.passes.ttgpuir.add_allocate_shared_memory_nv(pm, capability, ptx_version)
        nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
        nvidia.passes.ttnvgpuir.add_check_matmul_two_cta(pm)
        if knobs.compilation.instrumentation_mode == "consan":
            # Call ConcurrencySanitizerPass here, before allocating global scratch memory but after allocating tensor and shared
            passes.ttgpuir.add_concurrency_sanitizer(pm)
        passes.ttgpuir.add_allocate_global_scratch_memory(pm)
        nvidia.passes.ttnvgpuir.add_proxy_fence_insertion(pm, capability)
        # instrumentation point here so we can override IRs above (e.g., ttir and ttgir)
        if CUDABackend.instrumentation:
            CUDABackend.instrumentation.patch("ttgpuir_to_llvmir", pm, mod.context)
        nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
        nvidia.passes.ttnvgpuir.add_warp_specialize_to_llvm(pm)
        passes.common.add_canonicalizer(pm)
        passes.common.add_cse(pm)
        passes.common.add_symbol_dce(pm)
        passes.convert.add_nvvm_to_llvm(pm)
        if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables:
            passes.llvmir.add_di_scope(pm)

        if CUDABackend.instrumentation:
            CUDABackend.instrumentation.patch("llvmir_to_llvm", pm, mod.context)

        pm.run(mod, 'make_llir')

        if knobs.compilation.dump_ir_extract_di_local_variables:
            # comments below on why separate it
            if not knobs.compilation.disable_line_info:
                pm = ir.pass_manager(mod.context)
                pm.enable_debug()
                passes.llvmir.add_di_scope(pm)
                pm.run(mod, 'make_llir.disable_line_info')

            # Insert dbg intrinsics carrying several DI attributes, including
            # source variable names and type info. Note: for a reason not yet
            # understood, this pass and add_di_scope have to be run separately;
            # putting them into the previous pipeline triggers a segmentation
            # fault without any error message. This could be due to a bug in
            # mlir or pybind11.
            pm = ir.pass_manager(mod.context)
            pm.enable_debug()
            passes.llvmir.add_di_local_variable(pm)
            pm.run(mod, 'make_llir.dump_ir_extract_di_local_variables')

        # LLVM-IR (MLIR) -> LLVM-IR (LLVM)
        llvm.init_targets()
        context = llvm.context()
        if knobs.compilation.enable_asan:
            raise RuntimeError(
                "Address Sanitizer Error: Address sanitizer is currently only supported on the AMD backend")
        llvm_mod = llvm.to_module(mod, context)
        proc = sm_arch_from_capability(capability)
        features = get_features(options, self.target.arch)
        triple = 'nvptx64-nvidia-cuda'
        nvidia.set_short_ptr()
        llvm.attach_datalayout(llvm_mod, triple, proc, features)
        if options.enable_reflect_ftz:
            nvidia.set_nvvm_reflect_ftz(llvm_mod)

        if options.extern_libs and nvidia.has_extern_deps(llvm_mod):
            paths = [path for (name, path) in options.extern_libs]
            llvm.link_extern_libs(llvm_mod, paths)

        llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)

        # Get some metadata
        # warp-specialization mutates num_warps
        total_num_warps = src.get_int_attr("ttg.total-num-warps")
        if total_num_warps is not None:
            metadata["num_warps"] = total_num_warps
        metadata["shared"] = src.get_int_attr("ttg.shared")
        metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
        metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")
        metadata["global_scratch_align"] = src.get_int_attr("ttg.global_scratch_memory_alignment")
        metadata["profile_scratch_size"] = src.get_int_attr("ttg.profile_scratch_memory_size") or 0
        metadata["profile_scratch_align"] = src.get_int_attr("ttg.profile_scratch_memory_alignment") or 1
        ret = str(llvm_mod)
        # Drop the module before the context it was created in.
        del llvm_mod
        del context
        return ret

    def make_ptx(self, src, metadata, opt, capability):
        """Translate LLVM IR text to PTX and return it.

        Stores the single kernel entry name in ``metadata["name"]`` and rewrites
        the ``.version`` and ``.target`` directives to the selected PTX version
        and `capability`.
        """
        ptx_version = get_ptx_version_from_options(opt, self.target.arch)

        triple = 'nvptx64-nvidia-cuda'
        proc = sm_arch_from_capability(capability)
        features = get_features(opt, self.target.arch)
        flags = ["nvptx-mad-wide-opt"]
        ret = llvm.translate_to_asm(src, triple, proc, features, flags, opt.enable_fp_fusion, False)
        # Find kernel names (there should only be one)
        names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
        assert len(names) == 1
        metadata["name"] = names[0]
        # post-process
        ptx_version = f'{ptx_version//10}.{ptx_version%10}'
        ret = re.sub(r'\.version \d+\.\d+', f'.version {ptx_version}', ret, flags=re.MULTILINE)
        ret = re.sub(r'\.target sm_\d+', f'.target sm_{capability}', ret, flags=re.MULTILINE)
        if not knobs.compilation.dump_ir_extract_di_local_variables:
            # Remove the debug flag that prevents ptxas from optimizing the code
            # Note: if this flag is removed, the source var name and type info will be lost when ptx was compiled into cubin
            # and we may not be able to see them in cuda-gdb
            ret = re.sub(r",\s*debug|debug,\s*", "", ret)
        if knobs.nvidia.dump_nvptx:
            print("// -----// NVPTX Dump //----- //")
            print(ret)
        return ret

    def make_cubin(self, src, metadata, opt, capability):
        """Assemble PTX text into a cubin with `ptxas` and return its bytes.

        The PTX is written to a temporary file and the `ptxas` output is
        captured in a temporary log file. On success all temporary files are
        removed. On failure they are left on disk (the printed repro command
        references the ``.ptx`` file).

        Raises:
            PTXASError: if `ptxas` exits with a non-zero status.
        """
        ptxas = get_ptxas(self.target.arch).path
        with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.ptx') as fsrc, \
                tempfile.NamedTemporaryFile(delete=False, mode='r', suffix='.log') as flog:
            fsrc.write(src)
            fsrc.flush()
            fbin = fsrc.name + '.o'

            debug_info = []
            if knobs.compilation.disable_line_info:
                # This option is ignored if used without -lineinfo
                debug_info += ["-lineinfo", "-suppress-debug-info"]
            elif knobs.nvidia.disable_ptxas_opt:
                # Synthesize complete debug info
                debug_info += ["-g"]
            else:
                # Only emit line info
                debug_info += ["-lineinfo"]

            fmad = [] if opt.enable_fp_fusion else ["--fmad=false"]
            arch = sm_arch_from_capability(capability)

            # Disable ptxas optimizations if requested
            disable_opt = ['--opt-level', '0'] if knobs.nvidia.disable_ptxas_opt else []

            # Accept more ptxas options if provided. Split on any whitespace run
            # so repeated, leading or trailing spaces do not become empty
            # arguments on the ptxas command line.
            ptx_extra_options = opt.ptx_options.split() if opt.ptx_options else []

            ptxas_cmd = [
                ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name,
                '-o', fbin
            ]
            try:
                # close_fds=True on Windows and False on Linux, see https://github.com/triton-lang/triton/pull/4357
                # On Windows, both stdout and stderr need to be redirected to flog
                subprocess.run(ptxas_cmd, check=True, close_fds=True if os.name == 'nt' else False, stdout=flog,
                               stderr=flog)
                if knobs.nvidia.dump_ptxas_log:
                    with open(flog.name) as log_file:
                        print(log_file.read())

            except subprocess.CalledProcessError as e:
                with open(flog.name) as log_file:
                    log = log_file.read()

                if e.returncode == 255:
                    error = 'Internal Triton PTX codegen error'
                elif e.returncode == 128 + signal.SIGSEGV:
                    error = '`ptxas` raised SIGSEGV'
                else:
                    error = f'`ptxas` failed with error code {e.returncode}'

                error = (f"{error}\n"
                         f"`ptxas` stderr:\n{log}\n"
                         f'Repro command: {" ".join(ptxas_cmd)}\n')

                print(f"""
================================================================
{error}
{src}
================================================================
please share the reproducer above with Triton project.
""")
                raise PTXASError(error)

            with open(fbin, 'rb') as f:
                cubin = f.read()
            try_remove(fbin)

        # It's better to remove the temp files outside the context managers
        try_remove(fsrc.name)
        try_remove(flog.name)
        return cubin

    def add_stages(self, stages, options, language):
        """Register the compilation stage callables in `stages`.

        Triton sources get ``ttir`` and ``ttgir`` stages; Gluon sources get only
        a ``ttgir`` stage. ``llir``, ``ptx`` and ``cubin`` are always added.
        The inspection hook from `knobs.runtime`, if any, is called last.
        """
        capability = self._parse_arch(options.arch)
        if language == Language.TRITON:
            stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
            stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability)
        elif language == Language.GLUON:
            stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options, capability)
        stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability)
        # NOTE(review): the ptx and cubin stages are given self.target.arch, not
        # the capability parsed from options.arch - confirm this is intentional.
        stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch)
        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch)
        if knobs.runtime.add_stages_inspection_hook is not None:
            knobs.runtime.add_stages_inspection_hook(self, stages, options, language, capability)

    def hash(self):
        """Return a string identifying this backend: ``<ptxas version>-<target arch>``."""
        # Deliberately not wrapped in lru_cache: a cache on an instance method
        # keys on `self` and keeps backend instances alive, and the only
        # expensive part (get_ptxas_version) is already memoized.
        version = get_ptxas_version(self.target.arch)
        return f'{version}-{self.target.arch}'
|