compiler.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. from __future__ import annotations
  2. import hashlib
  3. import json
  4. from .._C.libtriton import get_cache_invalidating_env_vars, ir
  5. from ..backends import backends
  6. from ..backends.compiler import Language
  7. from ..backends.compiler import BaseBackend, GPUTarget
  8. from .. import __version__, knobs
  9. from ..runtime.autotuner import OutOfResources
  10. from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager, get_cache_key
  11. from ..runtime.driver import driver
  12. from ..tools.disasm import get_sass
  13. from pathlib import Path
  14. import re
  15. import functools
  16. import os
  17. import time
  18. import copy
  19. # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
  20. # and any following whitespace
  21. # - (public\s+)? : optionally match the keyword public and any following whitespace
  22. # - (@\w+) : match an @ symbol followed by one or more word characters
  23. # (letters, digits, or underscores), and capture it as group 1 (the function name)
  24. # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing
  25. # zero or more arguments separated by commas, and capture it as group 2 (the argument list)
  26. # - (attributes \{[\S\s]+\})? : optionally match attributes enclosed in braces and capture it as group 3
  27. ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)"
  28. prototype_pattern = {
  29. "ptx": ptx_prototype_pattern,
  30. }
  31. ptx_arg_type_pattern = r"\.param\s+\.(\w+)"
  32. arg_type_pattern = {
  33. "ptx": ptx_arg_type_pattern,
  34. }
  35. def convert_type_repr(x):
  36. # Currently we only capture the pointer type and assume the pointer is on global memory.
  37. # TODO: Capture and support shared memory space
  38. match = re.search(r'!tt\.ptr<([^,]+)', x)
  39. tma = re.search(r'tt.nv_tma_desc = 1', x)
  40. if tma is not None:
  41. return 'nvTmaDesc'
  42. x = re.sub(r' {[^}]+}', '', x)
  43. if match is not None:
  44. return '*' + convert_type_repr(match.group(1))
  45. return x
  46. class ASTSource:
  47. def __init__(self, fn, signature, constexprs=None, attrs=None) -> None:
  48. self.fn = fn
  49. self.language = Language.TRITON
  50. self.ext = "ttir"
  51. self.name = fn.__name__
  52. self.signature = signature
  53. self.constants = dict()
  54. if constexprs is not None:
  55. for k, v in constexprs.items():
  56. k = (fn.arg_names.index(k), ) if isinstance(k, str) else k
  57. assert isinstance(k, tuple)
  58. self.constants[k] = v
  59. self.attrs = attrs or dict()
  60. for k in self.signature.keys():
  61. if not isinstance(k, str):
  62. raise TypeError("Signature keys must be string")
  63. def hash(self):
  64. sorted_sig = [v for k, v in sorted(self.signature.items())]
  65. get_key = lambda x: x.cache_key if hasattr(x, 'cache_key') else str(x)
  66. constants_key = '-'.join([get_key(v) for k, v in sorted(self.constants.items())])
  67. key = f"{self.fn.cache_key}-{str(self.attrs)}-{sorted_sig}-{constants_key}"
  68. return hashlib.sha256(key.encode("utf-8")).hexdigest()
  69. def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
  70. from .code_generator import ast_to_ttir
  71. return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns,
  72. module_map=module_map)
  73. def parse_options(self):
  74. return dict()
  75. class IRSource:
  76. def __init__(self, path, context, backend):
  77. self.path = path
  78. path = Path(path)
  79. self.ext = path.suffix[1:]
  80. self.language = Language.TRITON
  81. self.src = path.read_text()
  82. ir.load_dialects(context)
  83. backend.load_dialects(context)
  84. # We don't have a easy-to-use PTX parser that we can use, so keep that regex for now.
  85. # TODO - replace with a proper parser
  86. if self.ext == "ptx":
  87. match = re.search(prototype_pattern[self.ext], self.src, re.MULTILINE)
  88. self.name = match.group(1)
  89. signature = match.group(2)
  90. types = re.findall(arg_type_pattern[self.ext], signature)
  91. self.signature = {k: convert_type_repr(ty) for k, ty in enumerate(types)}
  92. else:
  93. self.module = ir.parse_mlir_module(self.path, context)
  94. fn_name = self.module.get_entry_func_name()
  95. self.name = "@" + fn_name
  96. funcOp = self.module.get_function(fn_name)
  97. func_ty = self.module.get_function_signature(funcOp)
  98. self.signature = {k: ty for k, ty in enumerate(func_ty)}
  99. def hash(self):
  100. return hashlib.sha256(self.src.encode("utf-8")).hexdigest()
  101. def make_ir(self, target: GPUTarget, options, codegen_fns, module_map, context):
  102. self.module.context = context
  103. return self.module
  104. def parse_options(self):
  105. if self.ext == "ttgir":
  106. num_warps = self.module.get_int_attr("ttg.num-warps")
  107. assert num_warps is not None, "Unable to parse ttg.num-warps attribute"
  108. options = {'num_warps': num_warps}
  109. num_ctas = self.module.get_int_attr("ttg.num-ctas")
  110. if num_ctas is not None:
  111. options['num_ctas'] = num_ctas
  112. return options
  113. return dict()
  114. @functools.lru_cache()
  115. def max_shared_mem(device):
  116. return driver.active.utils.get_device_properties(device)["max_shared_mem"]
  117. def parse(full_name, ext, context):
  118. if ext == "ttir" or ext == "ttgir":
  119. module = ir.parse_mlir_module(full_name, context)
  120. module.context = context
  121. return module
  122. if ext == "llir" or ext == "ptx" or ext == "amdgcn":
  123. return Path(full_name).read_text()
  124. if ext == "cubin" or ext == "hsaco":
  125. return Path(full_name).read_bytes()
  126. def filter_traceback(e: BaseException):
  127. """
  128. Removes code_generator.py and related files from tracebacks.
  129. These are uninteresting to the user -- "just show me *my* code!"
  130. """
  131. if knobs.compilation.front_end_debugging:
  132. return
  133. if e.__cause__ is not None:
  134. filter_traceback(e.__cause__)
  135. if e.__context__ is not None:
  136. filter_traceback(e.__context__)
  137. # If a user has a file that matches one of these, they're out of luck.
  138. BAD_FILES = [
  139. "/triton/compiler/code_generator.py",
  140. "/ast.py",
  141. ]
  142. BAD_FILES = [bad_file.replace("/", os.sep) for bad_file in BAD_FILES]
  143. tb = e.__traceback__
  144. frames = []
  145. while tb is not None:
  146. if not any(f for f in BAD_FILES if tb.tb_frame.f_code.co_filename.endswith(f)):
  147. frames.append(tb)
  148. tb = tb.tb_next
  149. for (cur_frame, next_frame) in zip(frames, frames[1:]):
  150. cur_frame.tb_next = next_frame
  151. if not frames:
  152. e.__traceback__ = None
  153. else:
  154. frames[-1].tb_next = None
  155. e.__traceback__ = frames[0]
  156. class CompileTimer:
  157. def __init__(self) -> None:
  158. self.start: float = time.perf_counter()
  159. self.ir_initialization_end: float | None = None
  160. self.lowering_stage_ends: list[tuple[str, float]] = []
  161. self.store_results_end: float | None = None
  162. def finished_ir_initialization(self) -> None:
  163. self.ir_initialization_end = time.perf_counter()
  164. def stage_finished(self, stage_name: str) -> None:
  165. self.lowering_stage_ends.append((stage_name, time.perf_counter()))
  166. def end(self) -> knobs.CompileTimes:
  167. timestamp = time.perf_counter()
  168. if self.ir_initialization_end is None:
  169. self.ir_initialization_end = timestamp
  170. else:
  171. self.store_results_end = timestamp
  172. def delta(start: float, end: float | None) -> int:
  173. if end is None:
  174. return 0
  175. return int((end - start) * 1000000)
  176. lowering_stage_durations = []
  177. stage_start = self.ir_initialization_end
  178. for stage_name, stage_end in self.lowering_stage_ends:
  179. lowering_stage_durations.append((stage_name, delta(stage_start, stage_end)))
  180. stage_start = stage_end
  181. return knobs.CompileTimes(
  182. ir_initialization=delta(self.start, self.ir_initialization_end),
  183. lowering_stages=lowering_stage_durations,
  184. store_results=delta(stage_start, self.store_results_end),
  185. )
def compile(src, target=None, options=None, _env_vars=None):
    """Compile ``src`` for ``target`` and return a ``CompiledKernel``, reusing the on-disk cache when possible.

    Args:
        src: an ``ASTSource``, or a path (``str``) to an IR file, which is wrapped in ``IRSource``.
        target: the ``GPUTarget`` to compile for; defaults to ``driver.active.get_current_target()``.
        options: optional dict of backend options. Options parsed from the source itself (e.g. the
            ``ttg.num-warps`` attribute of a ``.ttgir`` file) replace entries with the same name.
        _env_vars: optional override for the cache-invalidating environment variables that are mixed
            into the cache key; defaults to ``get_cache_invalidating_env_vars()``.

    Returns:
        A ``CompiledKernel`` built from the cached artifacts in ``metadata_group``.

    Behavior controlled by ``knobs.compilation``:
        - ``always_compile`` skips the cache-hit early return.
        - ``override`` / ``dump_ir`` enable the per-stage override and dump managers.
        - ``store_binary_only`` limits which stage outputs are written to the cache.
        - ``use_ir_loc`` re-anchors locations to an IR file.
        - ``listener`` is called with timing information on both cache hits and fresh compiles.
    """
    compilation_listener = knobs.compilation.listener
    if compilation_listener:
        # Only timed when someone is listening; `timer` is referenced under the same condition below.
        timer = CompileTimer()

    if target is None:
        target = driver.active.get_current_target()
    assert isinstance(target, GPUTarget), "target must be of GPUTarget type"
    backend = make_backend(target)
    ir_source = not isinstance(src, ASTSource)
    # Anything that is not an ASTSource must be an IR file path; it is parsed eagerly in a fresh context.
    if ir_source:
        assert isinstance(src, str), "source must be either AST or a filepath"
        context = ir.context()
        src = IRSource(src, context, backend)

    extra_options = src.parse_options()
    # dict(a, **b): options embedded in the source win over caller-supplied ones.
    options = backend.parse_options(dict(options or dict(), **extra_options))
    # create cache manager
    env_vars = get_cache_invalidating_env_vars() if _env_vars is None else _env_vars
    key = get_cache_key(src, backend, options, env_vars=env_vars)
    hash = hashlib.sha256(key.encode("utf-8")).hexdigest()
    fn_cache_manager = get_cache_manager(hash)
    # For dumping/overriding only hash the source as we want it to be independent of triton
    # core changes to make it easier to track kernels by hash.
    enable_override = knobs.compilation.override
    enable_ir_dump = knobs.compilation.dump_ir
    store_only_binary = knobs.compilation.store_binary_only
    fn_override_manager = get_override_manager(src.hash()) if enable_override else None
    fn_dump_manager = get_dump_manager(src.hash()) if enable_ir_dump else None
    # Pre-truncate the file name here to avoid hitting the 255 character limit on common platforms.
    # The final file name in the cache will have a format of f"{filename}.{ext}.tmp.pid_{pid}_{uuid}".
    # A PID string can be 5-character long. A UUID string has typically 36 characters. Let's truncate
    # the file name to 150 characters to be safe.
    file_name = src.name[:150]
    metadata_filename = f"{file_name}.json"
    metadata_group = fn_cache_manager.get_group(metadata_filename) or {}
    metadata_path = metadata_group.get(metadata_filename)
    always_compile = knobs.compilation.always_compile
    if not always_compile and metadata_path is not None:
        # cache hit! The listener still gets a callback (with cache_hit=True).
        res = CompiledKernel(src, metadata_group, hash)
        if compilation_listener:
            compilation_listener(
                src=src,
                metadata=res.metadata._asdict(),
                metadata_group=metadata_group,
                times=timer.end(),
                cache_hit=True,
            )
        return res

    # initialize metadata: every backend option attribute and env var is recorded alongside the
    # hash and target, and the dict is handed to each stage below.
    metadata = {
        "hash": hash,
        "target": target,
        **options.__dict__,
        **env_vars,
    }
    metadata["triton_version"] = __version__
    # run compilation pipeline and populate metadata
    # `stages` maps stage extension -> callable(module, metadata); insertion order is pipeline order.
    stages = dict()
    backend.add_stages(stages, options, src.language)
    first_stage = list(stages.keys()).index(src.ext)
    # when the source is an IR file, don't apply the passes related to this stage. This makes it easier to write IR level tests.
    if ir_source:
        first_stage += 1

    # For IRSource, we have already grabbed the context + called both
    # ir.load_dialects and backend.load_dialects.
    if not isinstance(src, IRSource):
        context = ir.context()
        ir.load_dialects(context)
        backend.load_dialects(context)

    codegen_fns = backend.get_codegen_implementation(options)
    module_map = backend.get_module_map()
    try:
        module = src.make_ir(target, options, codegen_fns, module_map, context)
    except Exception as e:
        # Hide code-generator frames so the user sees their own kernel code.
        filter_traceback(e)
        raise

    # Cache the initial module: under its own extension for IR files, as ".source" for AST sources.
    if ir_source:
        ir_filename = f"{file_name}.{src.ext}"
        metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)
    else:
        ir_filename = f"{file_name}.source"
        metadata_group[ir_filename] = fn_cache_manager.put(module, ir_filename)

    use_ir_loc = knobs.compilation.use_ir_loc
    if ir_source and use_ir_loc:
        module.create_location_snapshot(src.path)
        print(f"Creating new locations for {src.path}")

    if compilation_listener:
        timer.finished_ir_initialization()
    for ext, compile_ir in list(stages.items())[first_stage:]:
        next_module = compile_ir(module, metadata)
        ir_filename = f"{file_name}.{ext}"
        if fn_override_manager is None:
            # Users can override kernels at scale by setting `ir_override` in autotune config
            # without TRITON_KERNEL_OVERRIDE
            # NOTE(review): `ir_override` presumably arrives via the backend options copied into
            # metadata above; confirm which backends define it.
            if (ir_override := metadata.get("ir_override", None)) and ir_override.endswith(f".{ext}"):
                next_module = parse(ir_override, ext, context)
        elif full_name := fn_override_manager.get_file(ir_filename):
            print(f"\nOverriding kernel with file {full_name}")
            next_module = parse(full_name, ext, context)
        # If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
        if (not store_only_binary) or (ext in ("cubin", "hsaco", "json")):
            metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
        if fn_dump_manager is not None:
            fn_dump_manager.put(next_module, ir_filename)
            # Also dump the disassembled SASS next to the cubin.
            if ext == "cubin":
                sass = get_sass(next_module)
                fn_dump_manager.put(sass, file_name + ".sass")
        # use an env variable to parse ir from file
        if use_ir_loc == ext:
            ir_full_name = fn_cache_manager.get_file(ir_filename)
            next_module.create_location_snapshot(ir_full_name)
            print(f"Creating new locations for {ir_full_name}")
        module = next_module
        if compilation_listener:
            timer.stage_finished(ext)
    # write-back metadata (stages may have added entries to it); `default=vars` serializes objects
    # such as the target via their __dict__.
    metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename,
                                                             binary=False)
    fn_cache_manager.put_group(metadata_filename, metadata_group)

    # notify any listener
    if compilation_listener:
        compilation_listener(src=src, metadata=metadata, metadata_group=metadata_group, times=timer.end(),
                             cache_hit=False)
    # return handle to compiled kernel
    return CompiledKernel(src, metadata_group, hash)
  312. def make_backend(target: GPUTarget) -> BaseBackend:
  313. actives = [x.compiler for x in backends.values() if x.compiler.supports_target(target)]
  314. if len(actives) != 1:
  315. raise RuntimeError(
  316. f"{len(actives)} compatible backends for target ({target.backend}) ({actives}). There should only be one.")
  317. return actives[0](target)
  318. class LazyDict:
  319. def __init__(self, data):
  320. self.data = data
  321. self.extras = []
  322. def get(self):
  323. for func, args in self.extras:
  324. self.data = self.data | func(*args)
  325. self.extras.clear()
  326. return self.data
  327. def add(self, func, args):
  328. self.extras.append((func, args))
  329. class AsmDict(dict):
  330. def __missing__(self, key):
  331. if key == "sass":
  332. value = get_sass(self["cubin"])
  333. else:
  334. raise KeyError("Unknown key: '%s'" % key)
  335. self[key] = value
  336. return value
  337. def _raise_error(err, *args, **kwargs):
  338. raise copy.deepcopy(err)
class CompiledKernel:
    """Handle to a compiled kernel whose artifacts are stored in the cache.

    Construction only reads cached files. The device module/function and the launcher are created
    lazily by ``_init_handles`` on first launch (``kernel[grid](...)``), on ``run`` access, or when
    ``launch_metadata`` needs them.
    """

    def __init__(self, src, metadata_group, hash):
        """Load the metadata and every cached artifact listed in ``metadata_group``.

        Args:
            src: the ``ASTSource`` / ``IRSource`` the kernel was compiled from.
            metadata_group: mapping from artifact file name to cached file path. The first entry whose
                name ends in ``.json`` is the metadata file (``next`` raises ``StopIteration`` if there
                is none). All other entries are loaded into ``self.asm`` keyed by file suffix.
            hash: the compilation cache key digest.
        """
        from collections import namedtuple
        metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
        metadata = json.loads(metadata_path.read_text())
        # JSON serialization dumps the target as a dict. Restore it to a GPUTarget.
        target = metadata['target']
        metadata['target'] = GPUTarget(target['backend'], target['arch'], target['warp_size'])
        # Immutable attribute-style view of the metadata, with fields in sorted key order.
        KernelMetadata = namedtuple('KernelMetadata', sorted(list(metadata.keys())))
        self.metadata = KernelMetadata(**metadata)
        backend = make_backend(self.metadata.target)
        self.packed_metadata = backend.pack_metadata(self.metadata)
        self.src = src
        self.hash = hash
        self.name = self.metadata.name
        # stores the text of each level of IR that was generated during compilation
        # (the backend's binary format is read as bytes, everything else as text)
        asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
        binary_ext = backend.binary_ext
        self.asm = AsmDict({
            file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
            for file in asm_files
        })
        self.metadata_group = metadata_group
        self.kernel = self.asm[binary_ext]
        # binaries are lazily initialized
        # because it involves doing runtime things
        # (e.g., checking amount of shared memory on current device)
        self.module = None
        self.function = None
        self._run = None

    def _init_handles(self):
        """Create the launcher and load the binary on the current device (no-op once loaded).

        Raises:
            OutOfResources: if the kernel needs more shared memory, tensor memory or threads than the
                device allows. In that case ``self._run`` is replaced by a callable that re-raises a
                copy of the error on every launch.
        """
        if self.module is not None:
            return

        def raise_(err):
            # clone the exception object so that the one saved in the closure
            # of the partial function below doesn't get assigned a stack trace
            # after the subsequent raise. otherwise, the CompiledKernel instance
            # saved in the (global) kernel cache will keep references to all the
            # locals in the traceback via the exception instance in the closure.
            cloned_err = copy.deepcopy(err)
            self._run = functools.partial(_raise_error, cloned_err)
            raise err

        device = driver.active.get_current_device()
        # create launcher
        self._run = driver.active.launcher_cls(self.src, self.metadata)
        # not enough shared memory to run the kernel
        max_shared = max_shared_mem(device)
        if self.metadata.shared > max_shared:
            raise_(OutOfResources(self.metadata.shared, max_shared, "shared memory"))
        if hasattr(self.metadata, "tmem_size") and self.metadata.tmem_size is not None:
            # Use blackwell max tmem size for now, this should be moved in device properties
            max_tmem_size = 512  # tmem size in number of columns
            if self.metadata.tmem_size > max_tmem_size:
                raise_(OutOfResources(self.metadata.tmem_size, max_tmem_size, "tensor memory"))
        # The start hook runs before load_binary, so it sees self.module as None (checked above).
        if knobs.runtime.kernel_load_start_hook is not None:
            knobs.runtime.kernel_load_start_hook(self.module, self.function, self.name, self.metadata_group, self.hash)
        # TODO: n_regs, n_spills should be metadata generated when calling `ptxas`
        self.module, self.function, self.n_regs, self.n_spills, self.n_max_threads = driver.active.utils.load_binary(
            self.name, self.kernel, self.metadata.shared, device)
        warp_size = driver.active.get_current_target().warp_size
        # self.module is already set at this point, so after this failure later calls return early
        # and launches go through the error-raising self._run installed by raise_.
        if self.metadata.num_warps * warp_size > self.n_max_threads:
            raise_(OutOfResources(self.metadata.num_warps * warp_size, self.n_max_threads, "threads"))
        if knobs.runtime.kernel_load_end_hook is not None:
            knobs.runtime.kernel_load_end_hook(self.module, self.function, self.name, self.metadata_group, self.hash)

    @property
    def run(self):
        """The launcher callable, initializing handles on first access."""
        if self._run is None:
            self._init_handles()
        return self._run

    def launch_metadata(self, grid, stream, *args):
        """Build the metadata passed to the launch hooks, or return None when no enter hook is set.

        The result is a ``LazyDict`` holding ``name``, ``function`` and ``stream``. For an
        ``ASTSource`` whose function defines ``launch_metadata``, that function is queued to be called
        with ``(grid, self.metadata, {arg_name: arg})`` when the dict is read.
        """
        if knobs.runtime.launch_enter_hook is None:
            return None
        self._init_handles()
        ret = LazyDict({"name": self.name, "function": self.function, "stream": stream})
        if not isinstance(self.src, ASTSource) or self.src.fn.launch_metadata is None:
            return ret
        arg_dict = {name: arg for name, arg in zip(self.src.fn.arg_names, args)}
        ret.add(self.src.fn.launch_metadata, (grid, self.metadata, arg_dict))
        return ret

    def __getitem__(self, grid):
        """Support ``kernel[grid](*args, stream=None)`` launch syntax.

        ``grid`` is indexed at 0, 1 and 2, so it must provide three dimensions. When ``stream`` is
        None, the current device's current stream is used.
        """
        self._init_handles()

        def runner(*args, stream=None):
            if stream is None:
                device = driver.active.get_current_device()
                stream = driver.active.get_current_stream(device)
            launch_metadata = self.launch_metadata(grid, stream, *args)
            self.run(grid[0], grid[1], grid[2], stream, self.function, self.packed_metadata, launch_metadata,
                     knobs.runtime.launch_enter_hook, knobs.runtime.launch_exit_hook, *args)

        return runner