jiterator.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # mypy: allow-untyped-defs
  2. import re
  3. from collections.abc import Callable
  4. import torch
  5. from torch import Tensor
  6. __all__: list[str] = []
  7. class _CodeParser:
  8. def __init__(self, code_string: str):
  9. optional_ws = r"\s*"
  10. required_ws = r"\s+"
  11. template_params = r"(?P<template_params>\<.+\>)"
  12. return_type = r"(?P<return_type>\w+)"
  13. function_name = r"(?P<function_name>\w+)"
  14. function_params = r"(?P<function_params>\(.+\))"
  15. function_body = r"(?P<function_body>\{.+\})"
  16. pattern = (
  17. optional_ws
  18. + "template"
  19. + optional_ws
  20. + template_params
  21. + optional_ws
  22. + return_type
  23. + required_ws
  24. + function_name
  25. + optional_ws
  26. + function_params
  27. + optional_ws
  28. + function_body
  29. + optional_ws
  30. )
  31. result = re.match(
  32. pattern, code_string, re.DOTALL
  33. ) # DOTALL for matching multiline
  34. if result is None:
  35. raise Exception( # noqa: TRY002
  36. f"Couldn't parse code, please check correctness:\n {code_string}"
  37. )
  38. self.template_params = result["template_params"]
  39. self.return_type = result["return_type"]
  40. self.function_name = result["function_name"]
  41. self.function_params = result["function_params"]
  42. self.function_body = result["function_body"]
  43. class _JittedFunction:
  44. def __init__(
  45. self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
  46. ):
  47. self.code_string = code_string
  48. if not (return_by_ref or num_outputs == 1):
  49. raise AssertionError("Return by value only works for single output.")
  50. self.return_by_ref = return_by_ref
  51. self.num_outputs = num_outputs
  52. parsed_code = _CodeParser(code_string)
  53. self.kernel_name = parsed_code.function_name
  54. self.kwargs_dict = kwargs
  55. self.is_cuda_available = torch.cuda.is_available()
  56. def __call__(self, *tensors: Tensor, **kwargs):
  57. # Jiterator follow torch.cuda's lazy initialization behavior
  58. # Defer checking cuda's availability at the function invocation time
  59. if not self.is_cuda_available:
  60. raise AssertionError(
  61. "Jiterator is only supported on CUDA and ROCm GPUs, none are available."
  62. )
  63. if len(tensors) > 8:
  64. raise AssertionError(
  65. f"jiterator only supports up to 8 tensor inputs, got {len(tensors)}"
  66. )
  67. expanded_kwargs = self.kwargs_dict.copy()
  68. for key, value in kwargs.items():
  69. if key in self.kwargs_dict:
  70. expanded_kwargs[key] = value
  71. else:
  72. raise KeyError(f"{key} is not declared in function definition")
  73. return torch._C._cuda_jiterator_compile_and_launch_kernel(
  74. self.code_string,
  75. self.kernel_name,
  76. self.return_by_ref,
  77. self.num_outputs,
  78. tensors,
  79. expanded_kwargs,
  80. )
  81. def _create_jit_fn(code_string: str, **kwargs) -> Callable:
  82. """
  83. Create a jiterator-generated cuda kernel for an elementwise op.
  84. The code string has to be a valid CUDA function that describes the computation for a single element. The code
  85. string has to follow the c++ template pattern, as shown in the example below. This function will be inlined
  86. into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as
  87. local temp dir.
  88. Jiterator-generated kernels accepts noncontiguous tensors, and supports broadcasting and type promotion.
  89. Args:
  90. code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
  91. kwargs (Dict, optional): Keyword arguments for generated function
  92. Example::
  93. code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
  94. jitted_fn = create_jit_fn(code_string, alpha=1.0)
  95. a = torch.rand(3, device="cuda")
  96. b = torch.rand(3, device="cuda")
  97. # invoke jitted function like a regular python function
  98. result = jitted_fn(a, b, alpha=3.14)
  99. code_string also allows multiple function definitions, and the last function will be treated as the entry function.
  100. Example::
  101. code_string = (
  102. "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
  103. )
  104. code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
  105. jitted_fn = create_jit_fn(code_string, val=0.0)
  106. a = torch.rand(3, device="cuda")
  107. b = torch.rand(3, device="cuda")
  108. # invoke jitted function like a regular python function
  109. result = jitted_fn(a, b) # using default val=0.0
  110. Jiterator can be used together with python registration to override an operator's cuda kernel.
  111. Following example is overriding gelu's cuda kernel with relu.
  112. Example::
  113. code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
  114. my_gelu = create_jit_fn(code_string)
  115. my_lib = torch.library.Library("aten", "IMPL")
  116. my_lib.impl("aten::gelu", my_gelu, "CUDA")
  117. # torch.nn.GELU and torch.nn.function.gelu are now overridden
  118. a = torch.rand(3, device="cuda")
  119. torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))
  120. .. warning::
  121. This API is in beta and may change in future releases.
  122. .. warning::
  123. This API only supports up to 8 inputs and 1 output
  124. .. warning::
  125. All input tensors must live in CUDA device
  126. """
  127. return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
  128. def _create_multi_output_jit_fn(
  129. code_string: str, num_outputs: int, **kwargs
  130. ) -> Callable:
  131. """
  132. Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.
  133. Args:
  134. code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.
  135. num_outputs(int): number of outputs return by the kernel
  136. kwargs (Dict, optional): Keyword arguments for generated function
  137. Example::
  138. code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
  139. jitted_fn = create_jit_fn(code_string, alpha=1.0)
  140. a = torch.rand(3, device="cuda")
  141. b = torch.rand(3, device="cuda")
  142. # invoke jitted function like a regular python function
  143. result = jitted_fn(a, b, alpha=3.14)
  144. .. warning::
  145. This API is in beta and may change in future releases.
  146. .. warning::
  147. This API only supports up to 8 inputs and 8 outputs
  148. """
  149. return _JittedFunction(
  150. code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
  151. )