# _filecheck.py
  1. import functools
  2. import os
  3. import inspect
  4. import subprocess
  5. import tempfile
  6. import triton
  7. from triton.compiler import ASTSource, make_backend
  8. from triton.backends.compiler import GPUTarget
  9. from triton.experimental.gluon._runtime import GluonASTSource
  10. from triton.runtime.jit import create_function_from_signature
  11. from triton._C.libtriton import ir
  12. # ===-----------------------------------------------------------------------===#
  13. # filecheck_test
  14. # ===-----------------------------------------------------------------------===#
  15. # Stub target for testing the frontend.
  16. stub_target = GPUTarget("cuda", 100, 32)
  17. triton_dir = os.path.dirname(__file__)
  18. filecheck_path = os.path.join(triton_dir, "FileCheck")
  19. class MatchError(ValueError):
  20. def __init__(self, message, module_str):
  21. super().__init__(message)
  22. self.module_str = module_str
  23. def __str__(self):
  24. return f"{super().__str__()}\n{self.module_str}"
  25. def run_filecheck(name, module_str, check_template):
  26. with tempfile.TemporaryDirectory() as tempdir:
  27. temp_module = os.path.join(tempdir, "module")
  28. with open(temp_module, "w") as temp:
  29. temp.write(module_str)
  30. temp_expected = os.path.join(tempdir, "expected")
  31. with open(temp_expected, "w") as temp:
  32. temp.write(check_template)
  33. try:
  34. subprocess.check_output(
  35. [filecheck_path, temp_expected, "--input-file", temp_module, "--dump-input-context=50"],
  36. stderr=subprocess.STDOUT)
  37. except subprocess.CalledProcessError as error:
  38. # Decode using OS native encoding
  39. decoded = error.output.decode().replace("\r\n", "\n")
  40. raise ValueError(decoded)
  41. def run_parser(kernel_fn, args=(), kwargs={}, target=stub_target):
  42. if "sanitize_overflow" not in kwargs:
  43. kwargs = dict(kwargs)
  44. kwargs["sanitize_overflow"] = False
  45. backend = make_backend(target)
  46. binder = create_function_from_signature(
  47. kernel_fn.signature,
  48. kernel_fn.params,
  49. backend,
  50. )
  51. bound_args, specialization, options = binder(*args, **kwargs)
  52. options, signature, constexprs, attrs = kernel_fn._pack_args(backend, kwargs, bound_args, specialization, options)
  53. source_cls = GluonASTSource if kernel_fn.is_gluon() else ASTSource
  54. src = source_cls(kernel_fn, signature, constexprs, attrs)
  55. context = ir.context()
  56. ir.load_dialects(context)
  57. backend.load_dialects(context)
  58. codegen_fns = backend.get_codegen_implementation(options)
  59. module_map = backend.get_module_map()
  60. module = src.make_ir(target, options, codegen_fns, module_map, context)
  61. return module
  62. def run_filecheck_test(kernel_fn):
  63. assert isinstance(kernel_fn, triton.runtime.JITFunction)
  64. check_template = inspect.getsource(kernel_fn.fn)
  65. if check_template is None:
  66. raise ValueError("kernel function must have a docstring with FileCheck template")
  67. mlir_module = run_parser(kernel_fn)
  68. run_filecheck("placeholder", mlir_module.str_nodebug(), check_template)
  69. def filecheck_test(fn):
  70. @functools.wraps(fn)
  71. def test_fn():
  72. run_filecheck_test(fn)
  73. return test_fn