gen_vmap_plumbing.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. from __future__ import annotations
  2. import textwrap
  3. from dataclasses import dataclass
  4. from typing import TYPE_CHECKING
  5. from torchgen.api.translate import translate
  6. from torchgen.api.types import DispatcherSignature
  7. from torchgen.context import method_with_native_function
  8. from torchgen.model import (
  9. Argument,
  10. BaseTy,
  11. BaseType,
  12. FunctionSchema,
  13. ListType,
  14. NativeFunction,
  15. OptionalType,
  16. Return,
  17. SchemaKind,
  18. Type,
  19. )
  20. from torchgen.utils import mapMaybe
  21. if TYPE_CHECKING:
  22. from collections.abc import Sequence
  23. def is_tensor(typ: Type) -> bool:
  24. return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor
  25. def is_optional_tensor(typ: Type) -> bool:
  26. return isinstance(typ, OptionalType) and is_tensor(typ.elem)
  27. def is_tensor_list(typ: Type) -> bool:
  28. return isinstance(typ, ListType) and is_tensor(typ.elem)
  29. def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
  30. result = f"""\
  31. auto [{name}_value, {name}_bdim] = unwrapTensorAtLevel({name}, {cur_level_var});"""
  32. return textwrap.dedent(result).split("\n")
  33. def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
  34. result = f"""\
  35. std::optional<Tensor> {name}_value;
  36. std::optional<int64_t> {name}_bdim;
  37. if ({name}) {{
  38. std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
  39. }}"""
  40. return textwrap.dedent(result).split("\n")
  41. def gen_unwraps(
  42. flat_arguments: Sequence[Argument], cur_level_var: str
  43. ) -> tuple[str, list[str]]:
  44. arg_names = [a.name for a in flat_arguments]
  45. arg_types = [a.type for a in flat_arguments]
  46. tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)]
  47. optional_tensors = [
  48. name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ)
  49. ]
  50. unwraps = []
  51. for tensor in tensors:
  52. unwraps += unwrap_tensor(tensor, cur_level_var)
  53. for opt_tensor in optional_tensors:
  54. unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var)
  55. unwrap_code = "\n".join(unwraps)
  56. unwrapped_arg_list = []
  57. for arg in arg_names:
  58. if arg in tensors or arg in optional_tensors:
  59. unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"]
  60. else:
  61. unwrapped_arg_list.append(arg)
  62. return unwrap_code, unwrapped_arg_list
  63. def gen_case_where_all_bdims_are_none(
  64. outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
  65. ) -> str:
  66. conditions = []
  67. flat_args = schema.arguments.flat_all
  68. for arg in flat_args:
  69. if not arg.type.is_tensor_like():
  70. continue
  71. conditions.append(f"!isBatchedAtLevel({arg.name}, {cur_level_var})")
  72. sig = DispatcherSignature.from_schema(schema)
  73. translated_args = ", ".join(
  74. e.expr for e in translate(outer_sig.arguments(), sig.arguments())
  75. )
  76. return f"""\
  77. if ({" && ".join(conditions)}) {{
  78. return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
  79. }}"""
  80. def gen_returns(
  81. returns: tuple[Return, ...], cur_level_var: str, results_var: str
  82. ) -> str:
  83. idx = 0
  84. wrapped_returns = []
  85. for ret in returns:
  86. if is_tensor(ret.type):
  87. wrapped_returns.append(
  88. f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
  89. )
  90. idx += 2
  91. elif is_tensor_list(ret.type):
  92. wrapped_returns.append(
  93. f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
  94. )
  95. idx += 2
  96. else:
  97. wrapped_returns.append(f"std::get<{idx}>({results_var})")
  98. idx += 1
  99. if len(wrapped_returns) == 1:
  100. result = f"return {wrapped_returns[0]};"
  101. else:
  102. result = f"return std::make_tuple({', '.join(wrapped_returns)});"
  103. return result
  104. def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
  105. return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)
  106. def is_mutated_arg(argument: Argument) -> bool:
  107. return argument.annotation is not None and argument.annotation.is_write
  108. def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
  109. # Assumptions:
  110. # - only one argument is being modified in-place
  111. # - the argument that is being modified in-place is the first argument
  112. # - all returns are either Tensor, tuple of Tensor, or TensorList
  113. schema = native_function.func
  114. sig = DispatcherSignature.from_schema(schema)
  115. returns = schema.returns
  116. # Check assumptions. If these are invalid we return None
  117. # and punt the work to handle them to the future.
  118. if schema.kind() != SchemaKind.inplace:
  119. raise AssertionError(f"Expected inplace schema, got {schema.kind()}")
  120. if not is_mutated_arg(schema.arguments.flat_all[0]):
  121. return None
  122. if len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) != 1:
  123. return None
  124. # Only support cases where all returns are Tensors or vector<Tensor>
  125. if len(returns) == 0:
  126. return None
  127. if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
  128. return None
  129. if not accepts_at_least_one_tensor_input(schema):
  130. return None
  131. cur_level_var = "cur_level"
  132. unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
  133. bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
  134. return f"""\
  135. template <typename batch_rule_t, batch_rule_t batch_rule>
  136. {sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
  137. c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  138. auto maybe_layer = maybeCurrentDynamicLayer();
  139. vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
  140. int64_t {cur_level_var} = maybe_layer->layerId();
  141. {textwrap.indent(bdims_all_none_case, " ")}
  142. {textwrap.indent(unwraps, " ")}
  143. batch_rule({", ".join(unwrapped_arg_list)});
  144. return {schema.arguments.flat_all[0].name};
  145. }}"""
  146. def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
  147. schema = native_function.func
  148. sig = DispatcherSignature.from_schema(schema)
  149. cur_level_var = "cur_level"
  150. unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
  151. bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
  152. return f"""\
  153. template <typename batch_rule_t, batch_rule_t batch_rule>
  154. {sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
  155. c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  156. auto maybe_layer = maybeCurrentDynamicLayer();
  157. vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
  158. int64_t {cur_level_var} = maybe_layer->layerId();
  159. {textwrap.indent(bdims_all_none_case, " ")}
  160. {textwrap.indent(unwraps, " ")}
  161. batch_rule({", ".join(unwrapped_arg_list)});
  162. }}"""
  163. def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
  164. schema = native_function.func
  165. sig = DispatcherSignature.from_schema(schema)
  166. returns = schema.returns
  167. # Only support cases where all returns are Tensors or vector<Tensor>
  168. if not accepts_at_least_one_tensor_input(schema):
  169. return None
  170. if len(returns) == 0:
  171. return gen_vmap_plumbing_no_returns(native_function)
  172. return_symint_overrides = [
  173. "_scaled_dot_product_flash_attention",
  174. "_scaled_dot_product_cudnn_attention",
  175. "_scaled_dot_product_flash_attention_quantized",
  176. ]
  177. if (
  178. not all(ret.type.is_tensor_like() for ret in returns)
  179. and schema.name.unambiguous_name() not in return_symint_overrides
  180. ):
  181. return None
  182. # in-place views need special handling
  183. if "inplace_view" in native_function.tags:
  184. return None
  185. if schema.kind() == SchemaKind.inplace:
  186. return gen_vmap_inplace_plumbing(native_function)
  187. # Don't support these (mutable, out, scratch)
  188. if schema.kind() != SchemaKind.functional:
  189. return None
  190. results_var = "results"
  191. cur_level_var = "cur_level"
  192. unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
  193. bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)
  194. wrapped_returns = gen_returns(returns, cur_level_var, results_var)
  195. return f"""\
  196. template <typename batch_rule_t, batch_rule_t batch_rule>
  197. {sig.decl(name=schema.name.unambiguous_name() + "_generated_plumbing")} {{
  198. c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  199. auto maybe_layer = maybeCurrentDynamicLayer();
  200. vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
  201. int64_t {cur_level_var} = maybe_layer->layerId();
  202. {textwrap.indent(bdims_all_none_case, " ")}
  203. {textwrap.indent(unwraps, " ")}
  204. auto {results_var} = batch_rule({", ".join(unwrapped_arg_list)});
  205. {wrapped_returns}
  206. }}"""
  207. @dataclass(frozen=True)
  208. class ComputeBatchRulePlumbing:
  209. @method_with_native_function
  210. def __call__(self, f: NativeFunction) -> str | None:
  211. result = gen_vmap_plumbing(f)
  212. return result
  213. def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
  214. body = "\n".join(list(mapMaybe(ComputeBatchRulePlumbing(), native_functions)))
  215. return f"""
  216. #pragma once
  217. #include <ATen/Operators.h>
  218. #include <ATen/functorch/PlumbingHelper.h>
  219. namespace at {{ namespace functorch {{
  220. {body}
  221. }}}} // namespace at::functorch
  222. """