# target_info.py (1.3 KB)
  1. from triton.runtime import driver
  2. from triton.runtime.jit import constexpr_function
  3. __all__ = ["current_target"]
  4. def current_target():
  5. try:
  6. active_driver = driver.active
  7. except RuntimeError:
  8. # If there is no active driver, return None
  9. return None
  10. return active_driver.get_current_target()
  11. current_target.__triton_builtin__ = True
  12. @constexpr_function
  13. def is_cuda():
  14. target = current_target()
  15. return target is not None and target.backend == "cuda"
  16. @constexpr_function
  17. def cuda_capability_geq(major, minor=0):
  18. """
  19. Determines whether we have compute capability >= (major, minor) and
  20. returns this as a constexpr boolean. This can be used for guarding
  21. inline asm implementations that require a certain compute capability.
  22. """
  23. target = current_target()
  24. if target is None or target.backend != "cuda":
  25. return False
  26. assert isinstance(target.arch, int)
  27. return target.arch >= major * 10 + minor
  28. @constexpr_function
  29. def is_hip():
  30. target = current_target()
  31. return target is not None and target.backend == "hip"
  32. @constexpr_function
  33. def is_hip_cdna3():
  34. target = current_target()
  35. return target is not None and target.arch == "gfx942"
  36. @constexpr_function
  37. def is_hip_cdna4():
  38. target = current_target()
  39. return target is not None and target.arch == "gfx950"