| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563 |
- from __future__ import annotations
- import functools
- import importlib
- import os
- import re
- import subprocess
- import sysconfig
- import pathlib
- import warnings
- from dataclasses import dataclass
- from contextlib import contextmanager
- from typing import cast, Any, Callable, Generator, Generic, Optional, Protocol, Type, TypeVar, TypedDict, TYPE_CHECKING, Union
- from triton._C.libtriton import getenv, getenv_bool # type: ignore
- if TYPE_CHECKING:
- from .runtime.cache import CacheManager, RemoteCacheBackend
- from .runtime.jit import JitFunctionInfo, KernelParam
- from .compiler.compiler import ASTSource, LazyDict, IRSource
class Env:
    """Marker type meaning "no Python-level override".

    Assigning an ``Env`` instance to a knob makes ``env_base.__set__`` drop any
    stored override, so later reads fall back to the descriptor's ``get()``.
    """


# Shared marker instance (e.g. ``knobs.x = env`` clears an override).
env = Env()

# When False, ``setenv`` is a no-op, so knob assignments are not mirrored
# into ``os.environ``.
propagate_env: bool = True
def setenv(key: str, value: Optional[str]) -> None:
    """Mirror a knob value into ``os.environ``.

    A ``None`` value removes ``key`` (if present); any other value is stored.
    Does nothing at all when the module-level ``propagate_env`` flag is false.
    """
    if propagate_env:
        if value is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = value
def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
    """Encode a knob value for storage in ``os.environ``.

    Returns a one-element tuple holding the string to store, or ``(None,)``
    meaning "remove the variable". Only ``None`` and values whose *exact* type
    is ``bool``, ``str`` or ``int`` are encodable (subclasses are not). Any
    other value yields ``None``, and the caller then leaves the environment
    untouched. Booleans are encoded as ``"1"`` / ``"0"``.
    """
    if val is None:
        return (None, )
    exact_type = type(val)
    if exact_type is bool:
        encoded = "1" if val else "0"
    elif exact_type is str:
        encoded = val
    elif exact_type is int:
        encoded = str(val)
    else:
        return None
    return (encoded, )
# There's an asymmetry here so that e.g. env_nvidia_tool can be specified with
# a string but return an NvidiaTool.
# SetType: type accepted when assigning a knob (stored as the Python override).
SetType = TypeVar("SetType")
# GetType: type returned when reading a knob (after env_base.transform).
GetType = TypeVar("GetType")
# Private sentinel used by env_base.__get__ to tell "no override stored" apart
# from an override whose value is None.
_NOTHING = object()
class env_base(Generic[SetType, GetType]):
    """Data descriptor implementing one environment-backed knob.

    Reading the attribute on an instance returns the Python-level override
    stored in that instance's ``__dict__`` if there is one (passed through
    ``transform``), otherwise ``self.get()``, which subclasses implement.
    Assigning a value stores the override and, when ``toenv`` can encode it,
    mirrors it into ``os.environ`` via ``setenv``. Assigning an ``Env``
    instance, or deleting the attribute, discards the override.
    """

    def __init__(self, key: str) -> None:
        # Name of the environment variable backing this knob.
        self.key = key

    def __set_name__(self, objclass: Type[object], name: str) -> None:
        # Attribute name on the owning class; also the instance ``__dict__``
        # key under which an override is stored.
        self.name = name

    def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> GetType:
        if obj is None:
            # Class-level access (``SomeKnobs.attr``): there is no instance
            # ``__dict__`` to consult. Follow the standard descriptor
            # convention and hand back the descriptor itself rather than
            # failing with AttributeError on ``None.__dict__``.
            return cast(GetType, self)
        py_val = obj.__dict__.get(self.name, _NOTHING)
        if py_val is _NOTHING:
            return self.get()
        return self.transform(py_val)

    def get(self) -> GetType:
        """Compute the knob's value when no override is set (subclass hook)."""
        raise NotImplementedError()

    def __set__(self, obj: object, value: Union[SetType, Env]) -> None:
        if isinstance(value, Env):
            # ``Env`` means "drop the override"; os.environ is left as is.
            obj.__dict__.pop(self.name, None)
        else:
            obj.__dict__[self.name] = value
            # Only values that toenv can encode are written through.
            if env_val := toenv(value):
                setenv(self.key, env_val[0])

    def __delete__(self, obj: object) -> None:
        # Removes only the Python-level override; os.environ is not touched.
        obj.__dict__.pop(self.name, None)

    def transform(self, val: SetType) -> GetType:
        # Converts a stored override (SetType) into the value returned on read
        # (GetType). Identity by default; only subclasses whose GetType differs
        # from SetType need to override it.
        return cast(GetType, val)
class env_str(env_base[str, str]):
    """String knob with a fixed fallback value.

    ``default`` is handed to ``getenv`` as the fallback for ``self.key``.
    """

    def __init__(self, key: str, default: str):
        super().__init__(key)
        # Fallback passed to getenv on every read.
        self.default = default

    def get(self) -> str:
        fallback = self.default
        return getenv(self.key, fallback)
class env_str_callable_default(env_base[str, str]):
    """String knob whose fallback is computed lazily.

    ``default_factory`` is invoked on each read for which ``getenv`` returns
    ``None``, so the fallback may depend on other knobs' current values.
    """

    def __init__(self, key: str, default_factory: Callable[[], str]):
        super().__init__(key)
        self.default_factory = default_factory

    def get(self) -> str:
        from_env = getenv(self.key)
        if from_env is not None:
            return from_env
        return self.default_factory()
class env_bool(env_base[bool, bool]):
    """Boolean knob; ``default`` (False unless given) is passed to ``getenv_bool``."""

    def __init__(self, key: str, default: bool = False) -> None:
        super().__init__(key)
        self.default = default

    def get(self) -> bool:
        fallback = self.default
        return getenv_bool(self.key, fallback)
class env_int(env_base[int, int]):
    """Integer knob.

    Returns ``default`` when the variable is unset. Raises ``RuntimeError``
    (chained from the ``ValueError``) when the value cannot be parsed by
    ``int()``.
    """

    def __init__(self, key: str, default: int = 0) -> None:
        super().__init__(key)
        self.default = default

    def get(self) -> int:
        val = getenv(self.key)
        if val is None:
            return self.default
        try:
            parsed = int(val)
        except ValueError as exc:
            raise RuntimeError(f"Unable to use {self.key}={val}: expected int") from exc
        return parsed
# Type parameter of env_class: the class type an env_class knob resolves to.
ClassType = TypeVar("ClassType")
class env_class(Generic[ClassType], env_base[Optional[Type[ClassType]], Optional[Type[ClassType]]]):
    """Knob naming a class as ``MODULE:CLASS``.

    Reads as ``None`` when unset. Otherwise imports ``MODULE``, fetches
    ``CLASS`` from it, and checks that some class in its MRO has the
    ``__name__`` given by ``type``. That is a name check, not an
    ``issubclass`` check. Raises ``RuntimeError`` for a malformed spec or a
    failed check.
    """

    def __init__(self, key: str, type: str) -> None:
        super().__init__(key)
        # Expected base-class *name*; the type itself isn't passed, to avoid
        # import cycles.
        self.type = type

    def get(self) -> Optional[Type[ClassType]]:
        val = getenv(self.key)
        if val is None:
            return None
        module_name, sep, class_name = val.partition(":")
        if not sep:
            raise RuntimeError(f"Unable to read {self.key}: '{val}' isn't of the form MODULE:CLASS")
        cls = getattr(importlib.import_module(module_name), class_name)
        if all(base.__name__ != self.type for base in cls.mro()):
            raise RuntimeError(f"Unable to use '{val}' from {self.key}: not of type '{self.type}'")
        return cast(Type[ClassType], cls)
@dataclass
class NvidiaTool:
    """An NVIDIA command-line tool located on disk, with its reported version."""

    # Filesystem path of the executable.
    path: str
    # "<major>.<minor>" parsed from the "release X.Y" text in `--version` output.
    version: str

    @staticmethod
    @functools.lru_cache
    def from_path(path: str) -> Optional[NvidiaTool]:
        """Probe ``path`` by running ``path --version``.

        Returns an ``NvidiaTool`` when the combined stdout/stderr contains
        ``release X.Y``, otherwise ``None``. A non-zero exit status, or any
        failure to launch the executable (missing, not executable, a
        directory, bad executable format: all ``OSError``), also yields
        ``None``, so callers can fall back to another candidate. Results are
        memoized per path.
        """
        try:
            result = subprocess.check_output([path, "--version"], stderr=subprocess.STDOUT)
        except (subprocess.CalledProcessError, OSError):
            # OSError (not just FileNotFoundError): e.g. PermissionError for a
            # non-executable file or a directory must also mean "not usable".
            return None
        # Replace undecodable bytes rather than failing on unusual output.
        output = result.decode("utf-8", errors="replace")
        version = re.search(r".*release (\d+\.\d+).*", output, flags=re.MULTILINE)
        if version is None:
            return None
        return NvidiaTool(path, version.group(1))
@functools.lru_cache
def find_nvidia_tool(binary: str) -> str:
    """Return the path of an executable NVIDIA tool, or "" if none is found.

    Checks ``backends/nvidia/bin/<binary>`` next to this file first. On
    Windows (``os.name == "nt"``) it then checks the CUDA bin directory
    reported by ``triton.windows_utils.find_cuda``, which is only consulted
    if the bundled copy isn't executable. Emits a warning when nothing
    usable is found. Results are memoized per ``binary``.
    """

    def _candidates():
        # Lazily yielded so find_cuda only runs when the bundled tool is missing.
        yield os.path.join(os.path.dirname(__file__), "backends", "nvidia", "bin", binary)
        if os.name == "nt":
            from triton.windows_utils import find_cuda
            cuda_bin_path, _, _ = find_cuda()
            if cuda_bin_path:
                yield os.path.join(cuda_bin_path, binary)

    for candidate in _candidates():
        if os.access(candidate, os.X_OK):
            return candidate
    warnings.warn(f"Failed to find executable {binary}")
    return ""
class env_nvidia_tool(env_base[str, NvidiaTool]):
    """Knob resolving an NVIDIA tool to an ``NvidiaTool``.

    The backing variable is ``TRITON_<BINARY>_PATH``, built from the binary
    name plus the configured executable suffix, with ``-`` mapped to ``_``.
    A user-supplied path is tried first, then the default location from
    ``find_nvidia_tool``. Raises ``RuntimeError`` if neither is usable.
    """

    def __init__(self, binary: str) -> None:
        # Append the interpreter's configured executable suffix (sysconfig
        # "EXE"); get_config_var returns None for undefined variables, which
        # would otherwise make ``+=`` raise TypeError.
        binary += sysconfig.get_config_var("EXE") or ""
        self.binary = binary
        # Convert ptxas-blackwell to PTXAS_BLACKWELL, not PTXAS-BLACKWELL
        super().__init__(f"TRITON_{binary.upper().replace('-', '_')}_PATH")

    def get(self) -> NvidiaTool:
        return self.transform(getenv(self.key))

    def transform(self, path: Optional[str]) -> NvidiaTool:
        default_path = find_nvidia_tool(self.binary)
        # The user-specified path (if any) goes first; the default is kept as a
        # fallback in case the pointed-to binary isn't usable. Empty candidates
        # (unset/empty override, or find_nvidia_tool's "" for "not found") are
        # skipped: "" is never a usable tool path, and handing it to subprocess
        # can raise an OSError instead of reaching the RuntimeError below.
        candidates = [candidate for candidate in (path, default_path) if candidate]
        for candidate in candidates:
            if tool := NvidiaTool.from_path(candidate):
                return tool
        raise RuntimeError(f"Cannot find {self.binary}")
# Separate classes so that types are correct
class env_opt_str(env_base[Optional[str], Optional[str]]):
    """String knob without a default: reads as whatever ``getenv`` returns
    for an unset variable (``None``)."""

    def get(self) -> Optional[str]:
        raw: Optional[str] = getenv(self.key)
        return raw
class env_opt_bool(env_base[Optional[bool], Optional[bool]]):
    """Tri-state boolean knob.

    ``getenv_bool`` is called with a ``None`` default, so an unset variable
    reads as ``None``. That lets consumers distinguish "not specified" from an
    explicit true/false.
    """

    # Previously typed as returning Optional[str] on an unparameterized
    # env_base; the value is a bool (or None), so both are corrected here.
    def get(self) -> Optional[bool]:
        return getenv_bool(self.key, None)
@dataclass(frozen=True)
class CompileTimes:
    """Immutable timing breakdown for a single compiler invocation.

    Every duration is expressed in microseconds.
    """

    # Time spent in make_ir.
    ir_initialization: int
    # Per-stage lowering durations in pipeline order, each entry keyed by the
    # stage's extension (for example "ttir" or "ttgir").
    lowering_stages: list[tuple[str, int]]
    # Time spent writing artifacts and metadata to the cache.
    store_results: int

    @property
    def total_lowering(self) -> int:
        """Sum of all lowering-stage durations."""
        return sum(entry[1] for entry in self.lowering_stages)

    @property
    def total(self) -> int:
        """IR initialization + all lowering stages + storing results."""
        lowering = self.total_lowering
        return self.ir_initialization + lowering + self.store_results
class CompilationListener(Protocol):
    """Callback type for ``compilation_knobs.listener``.

    Receives, as keyword arguments, the compiled source, its metadata, the
    metadata group, a ``CompileTimes`` breakdown and whether the result came
    from the cache. NOTE(review): presumably invoked once per compilation;
    confirm at the call site.
    """

    def __call__(self, *, src: Union[ASTSource, IRSource], metadata: dict[str, Any], metadata_group: dict[str, str],
                 times: CompileTimes, cache_hit: bool) -> None:
        ...
# Self-type for base_knobs methods that return an instance of the same
# subclass (copy, reset).
knobs_type = TypeVar("knobs_type", bound='base_knobs')
class base_knobs:
    """Base class for a group of knobs declared as ``env_base`` descriptors.

    Subclasses declare knobs as class attributes. Per-instance overrides are
    kept in the instance ``__dict__`` by the descriptors themselves.
    """

    @property
    def knob_descriptors(self) -> dict[str, env_base]:
        """Map attribute name -> descriptor for the knobs declared directly on
        ``type(self)`` (descriptors inherited from a base class are not
        included)."""
        return {
            k: v
            # data descriptors live on the class object
            for k, v in type(self).__dict__.items()
            if isinstance(v, env_base)
        }

    @property
    def knobs(self) -> dict[str, Any]:
        """Current value of every knob, keyed by attribute name."""
        return {k: getattr(self, k) for k in self.knob_descriptors.keys()}

    def copy(self: knobs_type) -> knobs_type:
        """Return a new instance carrying this one's overrides.

        This is a shallow copy of the instance ``__dict__``; class-level
        attributes stay shared between the copies.
        """
        res = type(self)()
        res.__dict__.update(self.__dict__)
        return res

    def reset(self: knobs_type) -> knobs_type:
        """Drop every knob override on this instance and return ``self``.

        Environment variables are not modified.
        """
        for knob in self.knob_descriptors.keys():
            delattr(self, knob)
        return self

    @contextmanager
    def scope(self) -> Generator[None, None, None]:
        """Context manager that restores overrides and knob env vars on exit.

        On entry it snapshots the instance ``__dict__`` and the current value
        of each knob's environment variable. On exit, normal or exceptional,
        both are put back, and variables that were unset on entry are removed.
        ``os.environ`` is written directly, regardless of ``propagate_env``.
        """
        # Snapshot *before* entering ``try``. If snapshotting raised inside the
        # ``try``, the ``finally`` clause would hit unbound locals and mask the
        # real error with an UnboundLocalError.
        initial_env = {knob.key: getenv(knob.key) for knob in self.knob_descriptors.values()}
        orig = dict(self.__dict__)
        try:
            yield
        finally:
            self.__dict__.clear()
            self.__dict__.update(orig)
            for k, v in initial_env.items():
                if v is not None:
                    os.environ[k] = v
                elif k in os.environ:
                    del os.environ[k]
class BuildImpl(Protocol):
    """Callable type for ``build_knobs.impl``.

    All parameters are positional-only. NOTE(review): the returned ``str`` is
    presumably the path of the built artifact; confirm against callers.
    """

    def __call__(self, name: str, src: str, srcdir: str, library_dirs: list[str], include_dirs: list[str],
                 libraries: list[str], /) -> str:
        ...
class build_knobs(base_knobs):
    """Configuration controlling how the native compiler is invoked"""
    # Each env_opt_str below reads as None when its variable is unset.
    cc: env_opt_str = env_opt_str("CC")
    cudacrt_path: env_opt_str = env_opt_str("TRITON_CUDACRT_PATH")
    cudart_path: env_opt_str = env_opt_str("TRITON_CUDART_PATH")
    # Optional custom build callable. This is a plain class attribute, not an
    # env-backed knob, so it does not appear in knob_descriptors.
    impl: Optional[BuildImpl] = None

    @property
    def backend_dirs(self) -> set[str]:
        """The CUDA CRT / runtime path overrides that are set (non-None)."""
        return {path for path in (self.cudacrt_path, self.cudart_path) if path is not None}
class redis_knobs(base_knobs):
    """Redis connection and key-naming settings (presumably for a Redis cache
    backend; confirm with the consumer)."""
    # NOTE(review): "{key}" / "{filename}" look like str.format placeholders
    # filled in by the consumer; confirm.
    key_format: env_str = env_str("TRITON_REDIS_KEY_FORMAT", "triton:{key}:{filename}")
    host: env_str = env_str("TRITON_REDIS_HOST", "localhost")
    port: env_int = env_int("TRITON_REDIS_PORT", 6379)
    # NOTE(review): annotation only. No value is assigned here, so this creates
    # no class attribute at runtime; confirm whether anything still relies on it.
    cache: cache_knobs
class cache_knobs(base_knobs):
    """Cache, dump and override directories plus pluggable cache manager classes."""
    # The default is expanded from "~/" once, when this class body executes.
    home_dir: env_str = env_str("TRITON_HOME", os.path.expanduser("~/"))
    # Unless overridden, each directory below is <home_dir>/.triton/<name>.
    # NOTE(review): the default factories use the module-level `cache` instance,
    # not the instance being read. A copy() with a different home_dir therefore
    # still derives these defaults from the global instance's home_dir.
    dump_dir = env_str_callable_default("TRITON_DUMP_DIR", lambda: cache.get_triton_dir("dump"))
    override_dir = env_str_callable_default("TRITON_OVERRIDE_DIR", lambda: cache.get_triton_dir("override"))
    dir = env_str_callable_default("TRITON_CACHE_DIR", lambda: cache.get_triton_dir("cache"))
    # "MODULE:CLASS" specs resolved by env_class; None when unset.
    manager_class: env_class[CacheManager] = env_class("TRITON_CACHE_MANAGER", "CacheManager")
    remote_manager_class: env_class[RemoteCacheBackend] = env_class("TRITON_REMOTE_CACHE_BACKEND", "RemoteCacheBackend")

    def get_triton_dir(self, dirname: str) -> str:
        """Return ``<home_dir>/.triton/<dirname>``."""
        return os.path.join(self.home_dir, ".triton", dirname)
class compilation_knobs(base_knobs):
    """Knobs controlling the compilation pipeline."""
    # env_bool knobs default to False unless a default is given.
    override: env_bool = env_bool("TRITON_KERNEL_OVERRIDE")
    dump_ir: env_bool = env_bool("TRITON_KERNEL_DUMP")
    dump_ir_extract_di_local_variables: env_bool = env_bool("LLVM_EXTRACT_DI_LOCAL_VARIABLES")
    store_binary_only: env_bool = env_bool("TRITON_STORE_BINARY_ONLY")
    always_compile: env_bool = env_bool("TRITON_ALWAYS_COMPILE")
    # TODO: Use enum to constrain / 'typecheck' the values
    use_ir_loc: env_opt_str = env_opt_str("USE_IR_LOC")
    enable_asan: env_bool = env_bool("TRITON_ENABLE_ASAN")
    disable_line_info: env_bool = env_bool("TRITON_DISABLE_LINE_INFO")
    front_end_debugging: env_bool = env_bool("TRITON_FRONT_END_DEBUGGING")
    allow_non_constexpr_globals: env_bool = env_bool("TRITON_ALLOW_NON_CONSTEXPR_GLOBALS")
    # Instrumentation mode is checked on every run, which is expensive.
    # We cache the value here to avoid the expensive check on every run.
    # This is a plain str read once when the class body executes, not an
    # env_base descriptor, so it is absent from knob_descriptors. refresh_knobs()
    # re-reads TRITON_INSTRUMENTATION_MODE.
    instrumentation_mode: str = env_str("TRITON_INSTRUMENTATION_MODE", "").get()
    # Optional CompilationListener callback; a plain attribute, not env-backed.
    listener: Union[CompilationListener, None] = None
class autotuning_knobs(base_knobs):
    """Autotuning-related knobs (both default to False)."""
    cache: env_bool = env_bool("TRITON_CACHE_AUTOTUNING")
    print: env_bool = env_bool("TRITON_PRINT_AUTOTUNING")
class LaunchHook(Protocol):
    """Hook invoked before and after kernel launching

    Type of the hooks held by ``runtime_knobs.launch_enter_hook`` and
    ``runtime_knobs.launch_exit_hook``.
    """

    def __call__(self, metadata: LazyDict) -> None:
        ...
class InitHandleHook(Protocol):
    """Hook invoked around kernel binary/module loading.
    module/function can be None for the *start* hook (before loading).

    Type of the hooks held by ``runtime_knobs.kernel_load_start_hook`` and
    ``runtime_knobs.kernel_load_end_hook``.
    """

    def __call__(
        self,
        module: Optional[object],
        function: Optional[Callable],
        name: str,
        metadata_group: dict[str, str],
        hash: str,
    ) -> None:
        ...
# Type of the hooks stored in a HookChain.
F = TypeVar("F", bound=Callable)


class HookChain(Generic[F]):
    """Ordered collection of hooks of type ``F``, invoked together.

    Calling the chain forwards its arguments to every registered hook. Hooks
    run in registration order, or in reverse order when the chain was built
    with ``reversed=True``. A hook that is already present (by equality) is
    not added twice.
    """

    def __init__(self, reversed: bool = False):
        self.calls: list[F] = []
        self.reversed = reversed

    def add(self, func: F) -> None:
        if func in self.calls:
            return
        self.calls.append(func)

    def remove(self, func: F) -> None:
        if func not in self.calls:
            return
        self.calls.remove(func)

    def __call__(self, *args, **kwargs):
        # ``reversed`` here is the builtin; the constructor flag is self.reversed.
        ordered = reversed(self.calls) if self.reversed else self.calls
        for hook in ordered:
            hook(*args, **kwargs)
# A single kernel attribute, of the form [attr_name, attr_val].
# TODO: Use tuple instead of list for better typing.
KernelAttr = list[Union[str, int]]
class JITHookCompileInfo(TypedDict):
    """Compilation details passed to a ``JITHook`` as its ``compile`` argument."""
    key: str
    signature: dict[KernelParam, str]
    device: int
    # NOTE(review): annotated as always None; confirm against the producer.
    constants: None
    num_warps: int
    num_ctas: int
    num_stages: int
    enable_fp_fusion: bool
    launch_cooperative_grid: bool
    extern_libs: tuple[tuple[str, str], ...]
    # Each entry maps an int tuple to a list of KernelAttr ([name, value]) items.
    configs: list[dict[tuple[int, ...], list[KernelAttr]]]
    specialization_data: str
    is_warmup: bool
class JITHook(Protocol):
    """Callable type for ``runtime_knobs.jit_cache_hook`` and
    ``runtime_knobs.jit_post_compile_hook``.

    NOTE(review): the meaning of the Optional[bool] return value is not
    visible here; check the call site before relying on it.
    """

    def __call__(self, *, key: str, repr: str, fn: JitFunctionInfo, compile: JITHookCompileInfo, is_manual_warmup: bool,
                 already_compiled: bool) -> Optional[bool]:
        ...
class PipelineStagesHook(Protocol):
    """Callable type for ``runtime_knobs.add_stages_inspection_hook``.

    NOTE(review): argument types are not declared here; confirm them against
    the compiler code that invokes the hook.
    """

    def __call__(self, stages, options, language, capability):
        ...
class runtime_knobs(base_knobs):
    """Runtime knobs and the hook chains consulted around launches/loads."""
    interpret: env_bool = env_bool("TRITON_INTERPRET")
    # debug is on critical path for kernel launches
    # avoid repeated reads from env-var by calling get directly
    # (read once when the class body executes; refresh_knobs() re-reads it)
    debug: bool = env_bool("TRITON_DEBUG").get()
    override_arch: env_opt_str = env_opt_str("TRITON_OVERRIDE_ARCH")
    # NOTE(review): these HookChain objects are class attributes, so every
    # runtime_knobs instance (including copy() results) shares them, and
    # scope() does not undo add()/remove() calls made on them.
    # The exit/end chains run hooks in reverse registration order.
    launch_enter_hook: HookChain[LaunchHook] = HookChain()
    launch_exit_hook: HookChain[LaunchHook] = HookChain(reversed=True)
    kernel_load_start_hook: HookChain[InitHandleHook] = HookChain()
    kernel_load_end_hook: HookChain[InitHandleHook] = HookChain(reversed=True)
    # Hook for inspecting compiled functions and modules
    jit_cache_hook: Optional[JITHook] = None
    # Hook to signal that a kernel is done compiling and inspect compiled function.
    # jit_cache_hook will always be called before compilation and jit_post_compile_hook after.
    jit_post_compile_hook: Optional[JITHook] = None
    # Hook for inspecting compiler pipeline stages
    add_stages_inspection_hook: Optional[PipelineStagesHook] = None
class language_knobs(base_knobs):
    """Knobs affecting language-level defaults."""
    # None when TRITON_F32_DEFAULT is unset.
    fp32_default: env_opt_str = env_opt_str("TRITON_F32_DEFAULT")
    default_fp_fusion: env_bool = env_bool("TRITON_DEFAULT_FP_FUSION", True)
class nvidia_knobs(base_knobs):
    """Knobs specific to the NVIDIA backend."""
    # Tool knobs: each reads a TRITON_<BINARY>_PATH variable derived from the
    # binary name (see env_nvidia_tool) and resolves to an NvidiaTool, falling
    # back to the default tool location.
    cuobjdump: env_nvidia_tool = env_nvidia_tool("cuobjdump")
    nvdisasm: env_nvidia_tool = env_nvidia_tool("nvdisasm")
    ptxas: env_nvidia_tool = env_nvidia_tool("ptxas")
    ptxas_blackwell: env_nvidia_tool = env_nvidia_tool("ptxas-blackwell")
    dump_nvptx: env_bool = env_bool("NVPTX_ENABLE_DUMP")
    disable_ptxas_opt: env_bool = env_bool("DISABLE_PTXAS_OPT")
    ptxas_options: env_opt_str = env_opt_str("PTXAS_OPTIONS")
    mock_ptx_version: env_opt_str = env_opt_str("TRITON_MOCK_PTX_VERSION")
    dump_ptxas_log: env_bool = env_bool("TRITON_DUMP_PTXAS_LOG")
    libdevice_path: env_opt_str = env_opt_str("TRITON_LIBDEVICE_PATH")
    libcuda_path: env_opt_str = env_opt_str("TRITON_LIBCUDA_PATH")
class amd_knobs(base_knobs):
    """Knobs specific to the AMD backend."""
    use_buffer_ops: env_bool = env_bool("AMDGCN_USE_BUFFER_OPS", True)
    # Note: This requires use_buffer_ops to be true to have any effect
    use_buffer_atomics: env_bool = env_bool("AMDGCN_USE_BUFFER_ATOMICS", True)
    # Note: This requires use_buffer_ops to be true to have any effect
    buffer_ops_analyze_small_tensor_range: env_bool = env_bool("AMDGCN_ANALYZE_SMALL_TENSOR_RANGE", False)
    dump_amdgcn: env_bool = env_bool("AMDGCN_ENABLE_DUMP")
    libhip_path: env_opt_str = env_opt_str("TRITON_LIBHIP_PATH")
    # These are optional bools (None when the variable is unset) so that a
    # default can be chosen based on other runtime info.
    use_block_pingpong: env_opt_bool = env_opt_bool("TRITON_HIP_USE_BLOCK_PINGPONG")
    use_in_thread_transpose: env_opt_bool = env_opt_bool("TRITON_HIP_USE_IN_THREAD_TRANSPOSE")
    use_async_copy: env_bool = env_bool("TRITON_HIP_USE_ASYNC_COPY")
    scalarize_packed_fops: env_bool = env_bool("AMDGCN_SCALARIZE_PACKED_FOPS")
class proton_knobs(base_knobs):
    """Knobs for Proton (TRITON_PROTON_*, CUPTI library location, NVTX)."""
    disable: env_bool = env_bool("TRITON_PROTON_DISABLE", False)
    # Default is the bundled backends/nvidia/lib/cupti directory next to this
    # file, computed once when the class body executes.
    cupti_lib_dir: env_str = env_str(
        "TRITON_CUPTI_LIB_PATH",
        str(pathlib.Path(__file__).parent.absolute() / "backends" / "nvidia" / "lib" / "cupti"))
    enable_nvtx: env_bool = env_bool("TRITON_ENABLE_NVTX", True)
# Module-level knob-group instances. Note that cache_knobs' default directory
# factories refer to the `cache` instance created here.
build = build_knobs()
redis = redis_knobs()
cache = cache_knobs()
compilation = compilation_knobs()
autotuning = autotuning_knobs()
runtime = runtime_knobs()
language = language_knobs()
nvidia = nvidia_knobs()
amd = amd_knobs()
proton = proton_knobs()
def refresh_knobs() -> None:
    """Re-read the knobs that are snapshotted when their class body executes.

    ``runtime.debug`` and ``compilation.instrumentation_mode`` are plain values
    captured once, to keep environment reads off hot paths. Later changes to
    TRITON_DEBUG / TRITON_INSTRUMENTATION_MODE only take effect after calling
    this.
    """
    debug_enabled = env_bool("TRITON_DEBUG").get()
    runtime.debug = debug_enabled
    mode = env_str("TRITON_INSTRUMENTATION_MODE", "").get()
    compilation.instrumentation_mode = mode
|