| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- from triton.runtime import driver
- from triton.runtime.jit import constexpr_function
- __all__ = ["current_target"]
def current_target():
    """Return the target reported by the active driver, or None.

    Reading ``driver.active`` may raise ``RuntimeError``; in that case
    there is no active driver and None is returned instead of propagating
    the error. Otherwise the result of the driver's
    ``get_current_target()`` is returned unchanged.
    """
    try:
        drv = driver.active
    except RuntimeError:
        # No active driver is available, so there is no target to report.
        return None
    else:
        # Deliberately outside the ``try`` so that a RuntimeError raised
        # by get_current_target() itself is not swallowed.
        return drv.get_current_target()


# Tag the function object with the ``__triton_builtin__`` attribute.
current_target.__triton_builtin__ = True
@constexpr_function
def is_cuda():
    """Tell whether the current target's backend is ``"cuda"``.

    Evaluates to False when ``current_target()`` returns None.
    """
    tgt = current_target()
    if tgt is None:
        return False
    return tgt.backend == "cuda"
@constexpr_function
def cuda_capability_geq(major, minor=0):
    """
    Check whether the current CUDA target has compute capability of at
    least (major, minor), returning the answer as a constexpr boolean.
    Useful for guarding inline asm implementations that need a minimum
    compute capability.

    The result is False when there is no current target or when the
    target's backend is not "cuda". Otherwise ``target.arch`` (asserted to
    be an int) is compared against ``major * 10 + minor``.
    """
    tgt = current_target()
    on_cuda = tgt is not None and tgt.backend == "cuda"
    if not on_cuda:
        return False
    assert isinstance(tgt.arch, int)
    # The requested capability is folded into a single integer,
    # e.g. (8, 6) -> 86, to match the integer arch value.
    minimum_arch = major * 10 + minor
    return tgt.arch >= minimum_arch
@constexpr_function
def is_hip():
    """Tell whether the current target's backend is ``"hip"``.

    Evaluates to False when ``current_target()`` returns None.
    """
    tgt = current_target()
    if tgt is None:
        return False
    return tgt.backend == "hip"
@constexpr_function
def is_hip_cdna3():
    """Tell whether the current target's arch is ``"gfx942"``.

    Only ``target.arch`` is compared; the backend field is not inspected.
    Evaluates to False when ``current_target()`` returns None.
    """
    tgt = current_target()
    if tgt is None:
        return False
    return tgt.arch == "gfx942"
@constexpr_function
def is_hip_cdna4():
    """Tell whether the current target's arch is ``"gfx950"``.

    Only ``target.arch`` is compared; the backend field is not inspected.
    Evaluates to False when ``current_target()`` returns None.
    """
    tgt = current_target()
    if tgt is None:
        return False
    return tgt.arch == "gfx950"
|