| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697 |
- import functools
- import os
- import inspect
- import subprocess
- import tempfile
- import triton
- from triton.compiler import ASTSource, make_backend
- from triton.backends.compiler import GPUTarget
- from triton.experimental.gluon._runtime import GluonASTSource
- from triton.runtime.jit import create_function_from_signature
- from triton._C.libtriton import ir
- # ===-----------------------------------------------------------------------===#
- # filecheck_test
- # ===-----------------------------------------------------------------------===#
# Directory containing this module; the FileCheck executable is looked up
# directly inside it.
triton_dir = os.path.dirname(__file__)
filecheck_path = os.path.join(triton_dir, "FileCheck")

# Stub target for testing the frontend.
stub_target = GPUTarget("cuda", 100, 32)
class MatchError(ValueError):
    """A ``ValueError`` that also carries the text of the module being checked.

    The string form of the exception is the original message followed by a
    newline and the stored module text.
    """

    def __init__(self, message, module_str):
        # Only the message is forwarded to ValueError, so ``args`` stays
        # ``(message,)``; the module text lives on its own attribute.
        super().__init__(message)
        self.module_str = module_str

    def __str__(self):
        summary = super().__str__()
        return f"{summary}\n{self.module_str}"
def run_filecheck(name, module_str, check_template):
    """Run the FileCheck executable on ``module_str`` using ``check_template``.

    Both strings are written to files inside a temporary directory. FileCheck
    is then invoked with the template file as its check file and the module
    file as ``--input-file``.

    Args:
        name: Not used by this function; kept so existing callers keep working.
        module_str: Text to be verified.
        check_template: Text containing the FileCheck directives.

    Raises:
        ValueError: If FileCheck exits with a non-zero status. The message is
            FileCheck's combined stdout/stderr with Windows line endings
            normalized to ``\\n``. The originating ``CalledProcessError`` is
            attached as ``__cause__``.
    """
    with tempfile.TemporaryDirectory() as tempdir:
        # Write both files as UTF-8 so the bytes FileCheck sees do not depend
        # on the platform's locale encoding and agree with the UTF-8 decoding
        # of its output below.
        temp_module = os.path.join(tempdir, "module")
        with open(temp_module, "w", encoding="utf-8") as temp:
            temp.write(module_str)

        temp_expected = os.path.join(tempdir, "expected")
        with open(temp_expected, "w", encoding="utf-8") as temp:
            temp.write(check_template)

        command = [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"]
        try:
            # stderr is folded into stdout so the full diagnostic is captured.
            subprocess.check_output(command, stderr=subprocess.STDOUT)
        except subprocess.CalledProcessError as error:
            # Decode as UTF-8, replacing undecodable bytes: a stray byte in the
            # tool output must not hide the actual FileCheck diagnostic behind
            # a UnicodeDecodeError.
            decoded = error.output.decode("utf-8", errors="replace").replace("\r\n", "\n")
            raise ValueError(decoded) from error
def run_parser(kernel_fn, args=(), kwargs={}, target=stub_target):
    """Build and return the IR module for ``kernel_fn`` on ``target``.

    Args:
        kernel_fn: Kernel object providing ``signature``, ``params``,
            ``_pack_args`` and ``is_gluon``.
        args: Positional arguments forwarded to the argument binder.
        kwargs: Keyword arguments forwarded to the argument binder. When
            ``"sanitize_overflow"`` is absent it is set to ``False`` on a
            copy, so the caller's mapping (and the shared default) is never
            modified here.
        target: Target used to create the backend and to build the IR.

    Returns:
        The module produced by the source object's ``make_ir``.
    """
    if "sanitize_overflow" not in kwargs:
        kwargs = {**kwargs, "sanitize_overflow": False}

    backend = make_backend(target)
    bind = create_function_from_signature(kernel_fn.signature, kernel_fn.params, backend)
    bound_args, specialization, options = bind(*args, **kwargs)
    options, signature, constexprs, attrs = kernel_fn._pack_args(
        backend,
        kwargs,
        bound_args,
        specialization,
        options,
    )

    # Gluon kernels use their own AST source type.
    if kernel_fn.is_gluon():
        source_cls = GluonASTSource
    else:
        source_cls = ASTSource
    src = source_cls(kernel_fn, signature, constexprs, attrs)

    # A fresh context with both the core and the backend dialects loaded.
    context = ir.context()
    ir.load_dialects(context)
    backend.load_dialects(context)

    codegen_fns = backend.get_codegen_implementation(options)
    module_map = backend.get_module_map()
    return src.make_ir(target, options, codegen_fns, module_map, context)
def run_filecheck_test(kernel_fn):
    """Check the IR generated for ``kernel_fn`` against its own source text.

    The complete source of the kernel's Python function (``kernel_fn.fn``) is
    used as the FileCheck template, and the module returned by ``run_parser``
    is rendered with ``str_nodebug()`` and handed to ``run_filecheck``.

    Args:
        kernel_fn: A ``triton.runtime.JITFunction``.

    Raises:
        AssertionError: If ``kernel_fn`` is not a ``JITFunction``.
        ValueError: If the kernel's source cannot be retrieved, or propagated
            from ``run_filecheck`` when the check fails.
    """
    assert isinstance(kernel_fn, triton.runtime.JITFunction)
    try:
        # inspect.getsource never returns None; it raises when the source is
        # unavailable, so that is the condition reported here.
        check_template = inspect.getsource(kernel_fn.fn)
    except (OSError, TypeError) as error:
        raise ValueError(
            "could not retrieve the kernel function's source to use as the FileCheck template") from error

    mlir_module = run_parser(kernel_fn)
    run_filecheck("placeholder", mlir_module.str_nodebug(), check_template)
def filecheck_test(fn):
    """Decorator turning a kernel into a zero-argument FileCheck test.

    The returned callable takes no arguments, calls ``run_filecheck_test(fn)``
    and returns ``None``. It carries ``fn``'s metadata (name, docstring,
    ``__wrapped__``) as applied by ``functools``.
    """

    def _run_kernel_check():
        run_filecheck_test(fn)

    return functools.update_wrapper(_run_kernel_check, fn)
|