register_dispatch_key.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035
  1. from __future__ import annotations
  2. import itertools
  3. import textwrap
  4. from dataclasses import dataclass
  5. from typing import Literal, TYPE_CHECKING
  6. from typing_extensions import assert_never
  7. import torchgen.api.cpp as cpp
  8. import torchgen.api.meta as meta
  9. import torchgen.api.structured as structured
  10. from torchgen.api.translate import translate
  11. from torchgen.api.types import (
  12. BaseCType,
  13. Binding,
  14. ConstRefCType,
  15. CppSignature,
  16. CppSignatureGroup,
  17. DispatcherSignature,
  18. Expr,
  19. kernel_signature,
  20. MutRefCType,
  21. NamedCType,
  22. NativeSignature,
  23. tensorT,
  24. )
  25. from torchgen.context import method_with_native_function, native_function_manager
  26. from torchgen.model import (
  27. Argument,
  28. BackendIndex,
  29. DeviceCheckType,
  30. DispatchKey,
  31. gets_generated_out_inplace_wrapper,
  32. is_cuda_dispatch_key,
  33. NativeFunction,
  34. NativeFunctionsGroup,
  35. SchemaKind,
  36. TensorOptionsArguments,
  37. )
  38. from torchgen.utils import mapMaybe, Target
  39. if TYPE_CHECKING:
  40. from torchgen.selective_build.selector import SelectiveBuilder
  41. def gen_registration_headers(
  42. backend_index: BackendIndex,
  43. per_operator_headers: bool,
  44. rocm: bool,
  45. ) -> list[str]:
  46. if per_operator_headers:
  47. headers = ["#include <ATen/ops/as_strided_native.h>"]
  48. else:
  49. headers = ["#include <ATen/NativeFunctions.h>"]
  50. if backend_index.dispatch_key in (DispatchKey.CPU, DispatchKey.Meta):
  51. headers.append("#include <ATen/EmptyTensor.h>")
  52. elif backend_index.dispatch_key == DispatchKey.CUDA:
  53. if rocm:
  54. headers.append("#include <ATen/hip/EmptyTensor.h>")
  55. else:
  56. headers.append("#include <ATen/cuda/EmptyTensor.h>")
  57. elif backend_index.dispatch_key == DispatchKey.MPS:
  58. headers.append("#include <ATen/mps/EmptyTensor.h>")
  59. elif backend_index.dispatch_key == DispatchKey.XPU:
  60. # XPU specific, this header resides in third_party/torch-xpu-ops
  61. headers.append("#include <ATen/xpu/EmptyTensor.h>")
  62. elif backend_index.dispatch_key == DispatchKey.MTIA:
  63. headers.append("#include <ATen/native/mtia/EmptyTensor.h>")
  64. elif per_operator_headers:
  65. headers += [
  66. "#include <ATen/ops/empty.h>",
  67. "#include <ATen/ops/empty_strided.h>",
  68. "#include <ATen/ops/_copy_from_and_resize.h>",
  69. "#include <ATen/ops/_copy_from.h>",
  70. ]
  71. else:
  72. headers.append("#include <ATen/Functions.h>")
  73. headers.append("#include <c10/macros/Macros.h>")
  74. return headers
  75. def gen_empty_impl_names(
  76. backend_index: BackendIndex,
  77. ) -> tuple[str | None, str | None]:
  78. empty_impl = None
  79. empty_strided_impl = None
  80. if backend_index.dispatch_key in (
  81. DispatchKey.Meta,
  82. DispatchKey.CPU,
  83. DispatchKey.CUDA,
  84. DispatchKey.MPS,
  85. DispatchKey.XPU,
  86. DispatchKey.MTIA,
  87. ):
  88. dispatch = str(backend_index.dispatch_key).lower()
  89. empty_impl = f"at::detail::empty_{dispatch}"
  90. empty_strided_impl = f"at::detail::empty_strided_{dispatch}"
  91. elif backend_index.dispatch_key in (
  92. DispatchKey.CompositeExplicitAutogradNonFunctional,
  93. DispatchKey.QuantizedCPU,
  94. DispatchKey.QuantizedCUDA,
  95. DispatchKey.XPU,
  96. ):
  97. empty_impl = "at::empty"
  98. empty_strided_impl = "at::empty_strided"
  99. return empty_impl, empty_strided_impl
  100. def gen_create_out_helper(backend_index: BackendIndex) -> list[str]:
  101. if backend_index.dispatch_key == DispatchKey.Meta:
  102. empty_options = "options.device(at::kMeta)"
  103. else:
  104. empty_options = "options"
  105. empty_impl, empty_strided_impl = gen_empty_impl_names(backend_index)
  106. if empty_impl is None:
  107. return []
  108. return [
  109. f"""
  110. Tensor create_out(IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  111. if (strides.empty()) {{
  112. return {empty_impl}(sizes, {empty_options});
  113. }} else {{
  114. return {empty_strided_impl}(sizes, strides, {empty_options});
  115. }}
  116. }}
  117. """
  118. ]
  119. def gen_maybe_create_proxy_helper(backend_index: BackendIndex) -> list[str]:
  120. _, empty_strided_impl = gen_empty_impl_names(backend_index)
  121. return (
  122. []
  123. if empty_strided_impl is None
  124. else [
  125. f"""
  126. std::optional<Tensor> maybe_create_proxy(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {{
  127. if (out.strides() != strides) {{
  128. return {empty_strided_impl}(sizes, strides, options);
  129. }}
  130. return std::nullopt;
  131. }}
  132. """
  133. ]
  134. )
  135. def gen_resize_out_helper(backend_index: BackendIndex) -> list[str]:
  136. if backend_index.dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
  137. # The function isn't used by this key (since only functional ops have a kernel for this key),
  138. # so we need to not include it to avoid a defined-but-not-used error.
  139. return []
  140. return [
  141. """
  142. void resize_out(const Tensor &out, IntArrayRef sizes, IntArrayRef strides, const TensorOptions &options) {
  143. TORCH_CHECK(options.dtype() == out.dtype(),
  144. "Expected out tensor to have dtype ", options.dtype(), ", but got ", out.dtype(), " instead");
  145. TORCH_CHECK(options.device() == out.device(),
  146. "Expected out tensor to have device ", options.device(), ", but got ", out.device(), " instead");
  147. const bool resized = at::native::resize_output(out, sizes);
  148. // Only restride if a resize occurred; otherwise we ignore the (advisory)
  149. // strides from the meta function and directly use the output tensor's
  150. // preexisting strides
  151. if (resized) {
  152. if (!strides.empty()) {
  153. TORCH_INTERNAL_ASSERT(!options.memory_format_opt().has_value());
  154. // TODO: avoid the redispatch here
  155. out.as_strided_(sizes, strides);
  156. } else if (options.memory_format_opt().has_value()) {
  157. out.unsafeGetTensorImpl()->empty_tensor_restride(*options.memory_format_opt());
  158. }
  159. }
  160. }
  161. """
  162. ]
  163. def gen_check_inplace_helper(backend_index: BackendIndex) -> list[str]:
  164. return [
  165. """
  166. void check_inplace(const Tensor &self, IntArrayRef sizes, const TensorOptions &options) {
  167. // These checks are needed on those operators that:
  168. // 1) don't use 'TensorIterator' (e.g. 'addmm' and 'baddbmm')
  169. // 2) have particular typing rules (e.g. 'cumsum' and 'cumprod')
  170. // For other operators (e.g. 'add'), 'TensorIterator' already checks
  171. // these things separately.
  172. TORCH_CHECK(options.dtype() == self.dtype(),
  173. "Bad in-place call: ",
  174. "input tensor dtype ", self.dtype(), " and output tensor dtype ", options.dtype(), " should match");
  175. TORCH_CHECK(options.device() == self.device(),
  176. "Bad in-place call: ",
  177. "input tensor device ", self.device(), " and output tensor device ", options.device(), " should match");
  178. TORCH_CHECK(sizes == self.sizes(),
  179. "Bad in-place call: ",
  180. "input tensor size ", self.sizes(), " and output tensor size ", sizes, " should match");
  181. }
  182. """
  183. ]
  184. def gen_registration_helpers(backend_index: BackendIndex) -> list[str]:
  185. return [
  186. 'C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function")',
  187. *gen_create_out_helper(backend_index),
  188. *gen_resize_out_helper(backend_index),
  189. *gen_check_inplace_helper(backend_index),
  190. *gen_maybe_create_proxy_helper(backend_index),
  191. "C10_DIAGNOSTIC_POP()",
  192. ]
  193. # Generates Register{dispatch}.cpp (e.g., RegisterCPU.cpp).
  194. #
  195. # - The primary function of this file is to register all of the
  196. # implementations for the given dispatch key to the dispatcher,
  197. # so they are available for use in PyTorch. If dispatch is
  198. # None, we generate schema (def) registrations and catchall
  199. # registrations.
  200. # - The secondary function of this file is to generate a wrapper
  201. # around functions. In CPUType these wrappers do nothing
  202. # (and should be removed), but in other cases they handle
  203. # DeviceGuard. A small extra benefit of wrappers is they
  204. # are not overloaded, so they can be used in the registration
  205. # API without having to disambiguate which overload you want
  206. # (as would be the case if you directly registered native::
  207. # functions).
  208. # - The tertiary function of this file is to generate *static*
  209. # cpp API bindings which can be used to bypass dispatcher
  210. # directly to kernels, but with user-friendly cpp-style API
@dataclass(frozen=True)
class RegisterDispatchKey:
    # Backend that code is generated for; its dispatch_key is read and it is
    # queried for kernels (has_kernel / get_kernel) throughout this class.
    backend_index: BackendIndex

    # Which kind of output to emit for each function: the anonymous wrapper
    # definition, the namespaced declaration/definition, or the m.impl(...)
    # registration line.
    target: Literal[
        Target.ANONYMOUS_DEFINITION,
        Target.NAMESPACED_DEFINITION,
        Target.NAMESPACED_DECLARATION,
        Target.REGISTRATION,
    ]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    # Whether or not we are actually code-genning for ROCm
    rocm: bool

    # Whether or not to generate symint registrations or not.  External users
    # of codegen who don't care about symints can set this to false to get
    # non-SymInt codegen
    symint: bool

    # The class that all unstructured native functions live under. This is used to improve
    # compiler error messages when a kernel writer adds a native function with the wrong signature.
    # This is only used in unstructured kernels, since structured kernels already live in a class.
    # Finally, this field is currently Optional because it is only used by external backends.
    # It would be nice if we can add the same logic to in-tree kernels too, but that requires updating
    # all of the existing kernel signatures scattered across aten/src/ATen/native.
    class_method_name: str | None

    # Only set to true in lightweight dispatch. If lightweight dispatch is enabled we are registering
    # operators into JIT op registry, thus we need to avoid generating code to register into the dispatcher.
    skip_dispatcher_op_registration: bool
  239. @staticmethod
  240. def gen_device_check(
  241. type: DeviceCheckType, args: list[Argument], method_name: str
  242. ) -> str:
  243. if type == DeviceCheckType.NoCheck:
  244. return " // No device check\n"
  245. device_check = "std::optional<Device> common_device = std::nullopt;\n"
  246. device_check += "(void)common_device; // Suppress unused variable warning\n"
  247. for arg in args:
  248. # Only tensor like arguments are eligible
  249. if arg.type.is_tensor_like():
  250. device_check += f"""
  251. c10::impl::check_and_update_common_device(common_device, {arg.name}, "{method_name}", "{arg.name}");"""
  252. return device_check
  253. @method_with_native_function
  254. def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
  255. if isinstance(f, NativeFunctionsGroup):
  256. g: NativeFunctionsGroup = f
  257. # Note: We call gen_structured() if the operator is marked structured, regardless of the backend.
  258. # gen_structured() has special logic to handle auto-generated kernels.
  259. if g.structured:
  260. return self.gen_structured(g)
  261. else:
  262. return list(
  263. mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions())
  264. )
  265. elif isinstance(f, NativeFunction):
  266. r = self.gen_unstructured(f)
  267. return [] if r is None else [r]
  268. else:
  269. assert_never(f)
  270. def wrapper_kernel_sig(
  271. self, f: NativeFunction
  272. ) -> NativeSignature | DispatcherSignature:
  273. # The prefix is just to ensure uniqueness. The Dispatcher API doesn't guarantee unique kernel names.
  274. return DispatcherSignature.from_schema(
  275. f.func,
  276. prefix=f"wrapper_{self.backend_index.dispatch_key}_{f.func.name.overload_name}_",
  277. symint=self.symint,
  278. )
  279. def gen_out_inplace_wrapper(
  280. self, f: NativeFunction, g: NativeFunctionsGroup | None
  281. ) -> str | None:
  282. if g is None:
  283. return None
  284. k = f.func.kind()
  285. if k is SchemaKind.inplace:
  286. copy_op = "at::_copy_from"
  287. elif k is SchemaKind.out:
  288. copy_op = "at::_copy_from_and_resize"
  289. else:
  290. raise AssertionError("gen_out_inplace_wrapper called on a functional op")
  291. sig = self.wrapper_kernel_sig(f)
  292. name = sig.name()
  293. func_res = f"{name}_tmp"
  294. return_names = cpp.return_names(f)
  295. if len(return_names) > 1:
  296. updates = "\n ".join(
  297. f"{copy_op}(std::get<{i}>({func_res}), {ret_name});"
  298. for i, ret_name in enumerate(return_names)
  299. )
  300. returns = f"{sig.returns_type().cpp_type()}({', '.join(return_names)})"
  301. elif len(return_names) == 1:
  302. ret_name = return_names[0]
  303. updates = f"{copy_op}({func_res}, {ret_name});"
  304. returns = ret_name
  305. else:
  306. if len(f.func.arguments.out) != 1:
  307. raise AssertionError(
  308. f"Expected exactly 1 out argument, got {len(f.func.arguments.out)}"
  309. )
  310. returns = ""
  311. out_arg = f.func.arguments.out[0]
  312. if out_arg.type.is_list_like():
  313. updates = f"""\
  314. for (int64_t i = 0; i < {func_res}.size(); ++i) {{
  315. {copy_op}({func_res}[i], {out_arg.name}[i]);
  316. }}"""
  317. else:
  318. updates = f"{copy_op}({func_res}, {out_arg.name});"
  319. functional_sig = self.wrapper_kernel_sig(g.functional)
  320. wrapper_name = sig.name()
  321. return f"""\
  322. {sig.defn(name=wrapper_name)} {{
  323. auto {func_res} = {functional_sig.name()}({", ".join(e.expr for e in translate(sig.arguments(), functional_sig.arguments()))});
  324. {updates}
  325. return {returns};
  326. }}
  327. """
  328. def gen_structured(self, g: NativeFunctionsGroup) -> list[str]:
  329. metadata = self.backend_index.get_kernel(g)
  330. if self.backend_index.dispatch_key == DispatchKey.Meta:
  331. if self.backend_index.has_kernel(g.out):
  332. raise AssertionError(
  333. "Do not explicitly specify Meta dispatch key on structured "
  334. "functions, they will be automatically generated for you"
  335. )
  336. elif (
  337. self.backend_index.dispatch_key
  338. == DispatchKey.CompositeExplicitAutogradNonFunctional
  339. ):
  340. if self.backend_index.has_kernel(g.out):
  341. raise AssertionError(
  342. "Do not explicitly specify CompositeExplicitAutograd dispatch key on "
  343. "structured functions, they will be automatically generated for you"
  344. )
  345. elif metadata is None or not metadata.structured:
  346. return list(mapMaybe(lambda f: self.gen_unstructured(f, g), g.functions()))
  347. structured_gen = StructuredRegisterDispatchKey(
  348. self.backend_index,
  349. self.target,
  350. self.selector,
  351. self.rocm,
  352. self.symint,
  353. self.class_method_name,
  354. self.skip_dispatcher_op_registration,
  355. g,
  356. )
  357. return list(mapMaybe(structured_gen.gen_one, g.functions()))
    def gen_unstructured(
        self, f: NativeFunction, g: NativeFunctionsGroup | None = None
    ) -> str | None:
        """Generate the code for one unstructured function, for ``self.target``.

        Depending on the target this is:

        * ``NAMESPACED_DECLARATION``: a ``TORCH_API`` declaration per C++
          signature;
        * ``NAMESPACED_DEFINITION``: a definition per C++ signature that
          forwards to the wrapper;
        * ``ANONYMOUS_DEFINITION``: the wrapper itself (an inplace-meta stub,
          a generated out/inplace wrapper, or a call into the backend kernel
          preceded by a device check and device guard);
        * ``REGISTRATION``: the ``m.impl(...)`` line for the wrapper.

        Args:
            f: The function to generate code for.
            g: The group ``f`` belongs to, if any.  It is needed to generate
                out/inplace wrappers.

        Returns:
            The generated code, or ``None`` when nothing is generated for
            ``f`` (see the early returns below).
        """
        with native_function_manager(f):
            inplace_meta = False
            gets_out_inplace_wrapper = False
            # Without a kernel for this backend, code is only generated in two
            # special cases; otherwise there is nothing to emit.
            if not self.backend_index.has_kernel(f):
                if (
                    self.backend_index.dispatch_key == DispatchKey.Meta
                    and f.func.kind() is SchemaKind.inplace
                    and
                    # Defer to composites for meta implementation
                    not f.has_composite_kernel
                    and
                    # Inplace list operations are not supported
                    len(f.func.returns) == 1
                ):
                    inplace_meta = True
                elif (
                    not self.backend_index.use_out_as_primary
                    and g is not None
                    and gets_generated_out_inplace_wrapper(f, g, self.backend_index)
                ):
                    # We want to generate inplace/out wrappers, that don't have a kernel for the backend.
                    gets_out_inplace_wrapper = True
                else:
                    return None
            if f.manual_kernel_registration:
                return None

            # Registration lines are only emitted for selected operators.
            if (
                self.target is Target.REGISTRATION
                and not self.selector.is_native_function_selected(f)
            ):
                return None

            sig = self.wrapper_kernel_sig(f)

            name = sig.name()
            returns_type = sig.returns_type().cpp_type()
            args = sig.arguments()
            args_str = ", ".join(a.defn() for a in args)

            # See Note [Direct dispatch bindings]
            cpp_sig_group = CppSignatureGroup.from_native_function(
                f, method=False, fallback_binding=False
            )

            # TODO: dedupe this with the structured codegen
            if self.target is Target.NAMESPACED_DECLARATION:
                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += f"TORCH_API {cpp_sig.decl()};\n"
                return result
            elif self.target is Target.NAMESPACED_DEFINITION:

                # Each C++ signature gets a definition that forwards its
                # arguments to the wrapper.
                def generate_defn(cpp_sig: CppSignature) -> str:
                    return f"""
{cpp_sig.defn()} {{
return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
}}
"""

                result = ""
                for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
                    result += generate_defn(cpp_sig)
                return result

            elif self.target is Target.ANONYMOUS_DEFINITION:
                # short circuit for inplace_meta
                if inplace_meta:
                    if f.func.arguments.self_arg is None:
                        raise AssertionError(
                            "Expected self_arg to be non-None for inplace_meta"
                        )
                    self_arg_name = f.func.arguments.self_arg.argument.name
                    # TODO: handle in place on tensor list
                    return f"""
{returns_type} {name}({args_str}) {{
TORCH_CHECK_NOT_IMPLEMENTED({self_arg_name}.is_meta(),
"Cannot inplace into non-meta tensor with meta tensor argument");
return {self_arg_name};
}}
"""

                # short circuit for generated inplace/out wrappers
                if gets_out_inplace_wrapper:
                    return self.gen_out_inplace_wrapper(f, g)

                metadata = self.backend_index.get_kernel(f)
                if metadata is None:
                    return None
                # Qualified C++ name of the kernel the wrapper calls.
                if self.class_method_name is None:
                    impl_name = f"{metadata.cpp_namespace}::{metadata.kernel}"
                else:
                    impl_name = f"{metadata.cpp_namespace}::{self.class_method_name}::{metadata.kernel}"

                kernel_sig = kernel_signature(f, self.backend_index)

                # Expressions that turn the wrapper's arguments into the
                # kernel's arguments.
                args_exprs_str = ", ".join(
                    e.expr
                    for e in translate(
                        sig.arguments(), kernel_sig.arguments(), method=False
                    )
                )

                device_check = " // No device check\n"
                # Backends that require device guards presumably also require device checks.
                if self.backend_index.device_guard:
                    device_check_args = itertools.chain(
                        f.func.arguments.out, f.func.arguments.flat_positional
                    )
                    device_check = RegisterDispatchKey.gen_device_check(
                        f.device_check, list(device_check_args), name
                    )

                device_guard = "// DeviceGuard omitted"  # default
                if f.device_guard and self.backend_index.device_guard:
                    has_tensor_options = any(
                        isinstance(a, TensorOptionsArguments)
                        for a in f.func.arguments.non_out
                    )
                    if has_tensor_options:
                        # kernel is creating a tensor
                        device_guard = """
const DeviceGuard device_guard(device_or_default(device));"""

                        # CUDA requires special handling
                        if is_cuda_dispatch_key(self.backend_index.dispatch_key):
                            device_guard = f"globalContext().lazyInitDevice(c10::DeviceType::CUDA);\n{device_guard}"
                    else:
                        # kernel is operating on existing tensors

                        # There is precedence for which argument we use to do
                        # device guard.  This describes the precedence order.
                        self_arg = (
                            [f.func.arguments.self_arg.argument]
                            if f.func.arguments.self_arg is not None
                            else []
                        )
                        candidate_args = itertools.chain(
                            self_arg,
                            f.func.arguments.out,
                            f.func.arguments.flat_positional,
                        )

                        # Only tensor like arguments are eligible
                        device_of = next(
                            (
                                f"{a.name}"
                                for a in candidate_args
                                if a.type.is_tensor_like()
                            ),
                            None,
                        )
                        # With no tensor-like argument the default
                        # "omitted" comment is kept.
                        if device_of is not None:
                            device_guard = f"const OptionalDeviceGuard device_guard(device_of({device_of}));"

                return f"""\
namespace {{
{returns_type} {name}({args_str}) {{
{device_check}
{device_guard}
return {impl_name}({args_exprs_str});
}}
}} // anonymous namespace
"""

            elif self.target is Target.REGISTRATION:
                # NOTE(review): manual_kernel_registration already returned
                # None earlier in this function, so only the
                # skip_dispatcher_op_registration half of this check can
                # still be true here.
                if f.manual_kernel_registration or self.skip_dispatcher_op_registration:
                    return None
                else:
                    payload = f"TORCH_FN({name})"
                    return f'm.impl("{f.func.name}",\n{payload});\n'
            else:
                assert_never(self.target)
  515. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  516. #
  517. # STRUCTURED
  518. #
  519. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
@dataclass(frozen=True)
class StructuredRegisterDispatchKey(RegisterDispatchKey):
    # The structured group being generated; RegisterDispatchKey.gen_structured
    # passes it in and then calls gen_one on each of its functions.
    g: NativeFunctionsGroup
  523. def gen_class_set_output_functions(
  524. self, k: SchemaKind, parent_class: str, generate_super: bool
  525. ) -> str:
  526. if generate_super:
  527. set_output_super = f"{parent_class}::set_output_raw_strided(output_idx, sizes, strides, options, names);"
  528. else:
  529. set_output_super = ""
  530. def gen_set_output_function(name: str, maybe_create_proxy: bool) -> str:
  531. return f"""
  532. void set_output_{name}(
  533. int64_t output_idx, IntArrayRef sizes, IntArrayRef strides,
  534. TensorOptions options, DimnameList names
  535. ) override {{
  536. {textwrap.indent(self.gen_class_set_output_body(k, maybe_create_proxy), " ")}
  537. if (!names.empty()) {{
  538. namedinference::propagate_names(outputs_[output_idx], names);
  539. }}
  540. // super must happen after, so that downstream can use maybe_get_output
  541. // to retrieve the output
  542. {textwrap.indent(set_output_super, " ")}
  543. }}
  544. """
  545. return f"""
  546. {gen_set_output_function("strided", maybe_create_proxy=True)}
  547. {gen_set_output_function("raw_strided", maybe_create_proxy=False)}
  548. """
  549. def gen_class_set_output_body(self, k: SchemaKind, maybe_create_proxy: bool) -> str:
  550. if self.backend_index.dispatch_key in [
  551. DispatchKey.CUDA,
  552. DispatchKey.MPS,
  553. DispatchKey.XPU,
  554. DispatchKey.CompositeExplicitAutogradNonFunctional,
  555. ]:
  556. maybe_set_guard = """
  557. auto current_device = guard_.current_device();
  558. if (C10_UNLIKELY(current_device.has_value())) {
  559. TORCH_INTERNAL_ASSERT(*current_device == options.device(),
  560. "structured kernels don't support multi-device outputs");
  561. } else {
  562. guard_.reset_device(options.device());
  563. }
  564. """
  565. maybe_set_guard_line = maybe_set_guard + "\n"
  566. else:
  567. maybe_set_guard_line = maybe_set_guard = ""
  568. if maybe_create_proxy:
  569. create_proxy = """
  570. auto maybe_proxy = maybe_create_proxy(out, sizes, strides, options);
  571. if (C10_UNLIKELY(maybe_proxy.has_value())) {
  572. proxy_outputs_[output_idx] = std::move(maybe_proxy).value();
  573. }
  574. """
  575. else:
  576. create_proxy = ""
  577. if k is SchemaKind.functional:
  578. if self.backend_index.dispatch_key not in (
  579. DispatchKey.Meta,
  580. DispatchKey.CPU,
  581. DispatchKey.CUDA,
  582. DispatchKey.MPS,
  583. DispatchKey.XPU,
  584. DispatchKey.MTIA,
  585. DispatchKey.CompositeExplicitAutogradNonFunctional,
  586. ):
  587. raise AssertionError(
  588. f"Unexpected dispatch key {self.backend_index.dispatch_key} "
  589. "for functional schema"
  590. )
  591. return f"""{maybe_set_guard_line}
  592. outputs_[output_idx] = create_out(sizes, strides, options);"""
  593. elif k is SchemaKind.inplace:
  594. return f"""{maybe_set_guard_line}
  595. const auto& out = outputs_[output_idx].get();
  596. check_inplace(out, sizes, options);
  597. {create_proxy}"""
  598. elif k is SchemaKind.out:
  599. return f"""{maybe_set_guard_line}
  600. const auto& out = outputs_[output_idx].get();
  601. resize_out(out, sizes, strides, options);
  602. {create_proxy}"""
  603. elif k is SchemaKind.mutable or k is SchemaKind.scratch:
  604. raise AssertionError(
  605. f"{k} structured operators are currently not supported"
  606. )
  607. else:
  608. assert_never(k)
  609. # returns the definition of a ctor, as well as how to construct
  610. # this class to a variable named op
  611. def gen_class_ctor(self, k: SchemaKind, class_name: str, returns: int) -> str:
  612. if k is SchemaKind.functional:
  613. return ""
  614. elif k is SchemaKind.inplace:
  615. # TODO: Make sure out argument is guaranteed to be self
  616. return f"{class_name}(Tensor& self) : outputs_{{std::ref(self)}} {{}}"
  617. elif k is SchemaKind.out:
  618. out_args = ", ".join(f"Tensor& out{i}" for i in range(returns))
  619. out_refs = ", ".join(f"std::ref(out{i})" for i in range(returns))
  620. return f"{class_name}({out_args}) : outputs_{{ {out_refs} }} {{}}"
  621. elif k is SchemaKind.mutable or k is SchemaKind.scratch:
  622. raise AssertionError(
  623. f"{k} structured operators are currently not supported"
  624. )
  625. else:
  626. assert_never(k)
  627. def gen_class(
  628. self,
  629. f: NativeFunction,
  630. k: SchemaKind,
  631. *,
  632. class_name: str,
  633. parent_class: str,
  634. generate_super: bool,
  635. ) -> str:
  636. if k is SchemaKind.functional:
  637. output_type = "Tensor"
  638. output_value = "outputs_[output_idx]"
  639. proxy_field = ""
  640. elif k is SchemaKind.inplace:
  641. output_type = "std::reference_wrapper<Tensor>"
  642. output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
  643. proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
  644. elif k is SchemaKind.out:
  645. output_type = "std::reference_wrapper<Tensor>"
  646. output_value = "proxy_outputs_[output_idx].has_value() ? *proxy_outputs_[output_idx] : outputs_[output_idx].get()"
  647. proxy_field = f"std::array<::std::optional<Tensor>, {len(f.func.returns)}> proxy_outputs_;"
  648. else:
  649. raise RuntimeError(f"Unsupported SchemaKind {k}")
  650. if self.backend_index.dispatch_key == DispatchKey.CUDA:
  651. guard_field = "c10::cuda::OptionalCUDAGuard guard_;"
  652. elif (
  653. self.backend_index.dispatch_key
  654. == DispatchKey.CompositeExplicitAutogradNonFunctional
  655. ):
  656. guard_field = "c10::OptionalDeviceGuard guard_;"
  657. elif self.backend_index.dispatch_key == DispatchKey.MPS:
  658. # TODO: Move to OptionalMPSGuard.
  659. guard_field = "c10::OptionalDeviceGuard guard_;"
  660. elif self.backend_index.dispatch_key == DispatchKey.XPU:
  661. guard_field = "c10::OptionalDeviceGuard guard_;"
  662. elif self.backend_index.dispatch_key == DispatchKey.MTIA:
  663. guard_field = "c10::OptionalDeviceGuard guard_;"
  664. else:
  665. guard_field = ""
  666. indent = " " * 4
  667. class_ctor_str = self.gen_class_ctor(k, class_name, len(f.func.returns))
  668. lines = (
  669. f"struct {class_name} final : public {parent_class} {{",
  670. f"{textwrap.indent(class_ctor_str, indent)}",
  671. f"{textwrap.indent(self.gen_class_set_output_functions(k, parent_class, generate_super), indent)}",
  672. " const Tensor& maybe_get_output(int64_t output_idx) override {",
  673. f" return {output_value};\n", # type: ignore[possibly-undefined] # TODO: audit
  674. " }",
  675. # type: ignore[possibly-undefined] # TODO: audit
  676. f" std::array<{output_type}, {len(f.func.returns)}> outputs_;",
  677. f"{textwrap.indent(proxy_field, indent)}", # type: ignore[possibly-undefined] # TODO: audit
  678. f"{textwrap.indent(guard_field, indent)}",
  679. "};",
  680. )
  681. return "\n".join(line for line in lines if line)
  682. @method_with_native_function
  683. def gen_one(self, f: NativeFunction) -> str | None:
  684. if f.manual_kernel_registration:
  685. raise AssertionError(
  686. f"Function {f.func.name} has manual_kernel_registration=True"
  687. )
  688. if (
  689. self.target is Target.REGISTRATION
  690. and not self.selector.is_native_function_selected(f)
  691. ):
  692. return None
  693. # TODO: Now, there is something interesting going on here. In the code below,
  694. # we generate CompositeExplicitAutogradNonFunctional implementations of functional and inplace
  695. # based on the out implementation. But in fact, out is definable by
  696. # functional too (just not very efficiently), and this is honestly the
  697. # MORE likely situation for a backend implementer. How do we pick?
  698. # Well, taking a page from Haskell type classes and default methods,
  699. # we could conceivably register a circular definition (out in terms
  700. # of functional, and functional in terms of out) and just require
  701. # someone to implement one or the other. We'd have to do a little bit
  702. # of work to not register one of these "weak" definitions unless there
  703. # is a strong definition somewhere in the DAG! So it's not implemented yet.
  704. if (
  705. self.backend_index.dispatch_key
  706. == DispatchKey.CompositeExplicitAutogradNonFunctional
  707. and f.func.kind() is SchemaKind.out
  708. ):
  709. # Never generate a default implementation for out, that's what you
  710. # have to define as a backend implementer
  711. return None
  712. # Note [Direct dispatch bindings]
  713. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  714. # Signature of the non-dispatched function we'll expose in a header
  715. # (e.g., at::cpu::add). We don't generate methods (TODO: do this
  716. # when CPUTensor class is a thing); nor do we generate fallback
  717. # bindings for manual_cpp_binding functions.
  718. cpp_sig_group = CppSignatureGroup.from_native_function(
  719. f, method=False, fallback_binding=False
  720. )
  721. # Signature of the wrapper function we'll register to the dispatcher
  722. kern = self.backend_index.get_kernel(f)
  723. sig = NativeSignature(
  724. f.func,
  725. prefix=f"wrapper_{self.backend_index.dispatch_key}_",
  726. symint=kern is not None and kern.supports_symint(),
  727. )
  728. if self.target is Target.NAMESPACED_DECLARATION:
  729. result = ""
  730. for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
  731. result += f"TORCH_API {cpp_sig.decl()};\n"
  732. return result
  733. elif self.target is Target.NAMESPACED_DEFINITION:
  734. def generate_defn(cpp_sig: CppSignature) -> str:
  735. return f"""
  736. {cpp_sig.defn()} {{
  737. return {sig.name()}({", ".join(e.expr for e in translate(cpp_sig.arguments(), sig.arguments()))});
  738. }}
  739. """
  740. result = ""
  741. for cpp_sig in cpp_sig_group.signatures(symint=self.symint):
  742. result += generate_defn(cpp_sig)
  743. return result
  744. elif self.target is Target.ANONYMOUS_DEFINITION:
  745. k = f.func.kind()
  746. # Construct the body of the wrapper function with signature sig
  747. sig_body = []
  748. # We'll use context to keep track of any variables we've brought
  749. # into scope while generating code
  750. context: list[Binding | Expr] = list(sig.arguments())
  751. # Initialize the class corresponding to this structured
  752. # operator; feeding it the output argument(s) if it is known
  753. if self.backend_index.dispatch_key is DispatchKey.Meta:
  754. class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
  755. parent_class = f"at::meta::structured_{meta.name(self.g)}"
  756. elif (
  757. self.backend_index.dispatch_key
  758. is DispatchKey.CompositeExplicitAutogradNonFunctional
  759. ):
  760. # TODO: dedup this branch
  761. class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
  762. parent_class = f"at::meta::structured_{meta.name(self.g)}"
  763. else:
  764. metadata = self.backend_index.get_kernel(self.g)
  765. if metadata is None:
  766. raise AssertionError(
  767. f"No kernel metadata found for {self.g.functional.func.name}"
  768. )
  769. class_name = f"structured_{metadata.kernel}_{k.name}"
  770. parent_class = f"{metadata.cpp_namespace}::structured_{metadata.kernel}"
  771. if self.backend_index.device_guard:
  772. device_check_args = itertools.chain(
  773. f.func.arguments.out, f.func.arguments.flat_positional
  774. )
  775. sig_body.append(
  776. RegisterDispatchKey.gen_device_check(
  777. f.device_check, list(device_check_args), sig.name()
  778. )
  779. )
  780. if k is SchemaKind.functional:
  781. sig_body.append(f"{class_name} op;")
  782. elif k is SchemaKind.inplace:
  783. sig_body.append(f"{class_name} op(self);")
  784. elif k is SchemaKind.out:
  785. out_args_str = ", ".join(a.name for a in f.func.arguments.out)
  786. sig_body.append(f"{class_name} op({out_args_str});")
  787. # Translate the input native arguments into structured
  788. # arguments for the meta call
  789. meta_exprs = ", ".join(
  790. e.expr
  791. for e in translate(
  792. context, structured.meta_arguments(self.g), method=False
  793. )
  794. )
  795. if self.g.out.precomputed:
  796. # If this function group has precomputed elements, the meta function
  797. # returns a struct containing them which must be saved so that it
  798. # can be unpacked when generating code to call the impl.
  799. sig_body.append(f"auto precompute = op.meta({meta_exprs});")
  800. # Put all of the contents of the precompute struct into the context
  801. # so that translate will be able to return the correct args for the
  802. # call to the impl.
  803. precomputed_values = [
  804. *self.g.out.precomputed.replace.values(),
  805. self.g.out.precomputed.add,
  806. ]
  807. for precomputed_elems in precomputed_values:
  808. context.extend(
  809. Expr(
  810. expr=f"precompute.{arg.name}",
  811. type=structured.argument_type(arg, binds=arg.name),
  812. )
  813. for arg in precomputed_elems
  814. )
  815. # Add a use of the precompute struct so FB internal compilers don't
  816. # complain that there is an unused variable.
  817. sig_body.append("(void)precompute;")
  818. else:
  819. sig_body.append(f"op.meta({meta_exprs});")
  820. # After running meta, op.outputs_ is guaranteed to be valid;
  821. # add it to the context
  822. out_args = structured.out_arguments(self.g)
  823. for i, out_arg in enumerate(out_args):
  824. if ConstRefCType(BaseCType(tensorT)) != out_arg.nctype.type:
  825. raise AssertionError(
  826. f"Expected out_arg type to be ConstRefCType(BaseCType(tensorT)), "
  827. f"got {out_arg.nctype.type}"
  828. )
  829. if k is SchemaKind.out:
  830. expr = f"op.maybe_get_output({i})"
  831. else:
  832. expr = f"op.outputs_[{i}]"
  833. context.append(
  834. Expr(
  835. expr=expr,
  836. # TODO: Stop hardcoding that the output type is a Tensor. Note
  837. # that for the codegen here this is fine because outputs_ is
  838. # hardcoded to be tensor already
  839. type=NamedCType(
  840. out_arg.nctype.name, MutRefCType(BaseCType(tensorT))
  841. ),
  842. )
  843. )
  844. # With the expanded context, do the impl call (if not a meta
  845. # function)
  846. if (
  847. self.backend_index.dispatch_key
  848. == DispatchKey.CompositeExplicitAutogradNonFunctional
  849. ):
  850. # TODO: https://github.com/pytorch/pytorch/issues/53023
  851. out_sig_group = CppSignatureGroup.from_native_function(
  852. self.g.out, method=False, fallback_binding=f.manual_cpp_binding
  853. )
  854. out_sig = out_sig_group.most_faithful_signature()
  855. api_name = out_sig.name()
  856. out_exprs = ", ".join(
  857. e.expr
  858. for e in translate(context, out_sig.arguments(), method=False)
  859. )
  860. # TODO: I think this means structured won't work with method
  861. # only functions (but maybe you're saved by faithful? iunno.)
  862. # NB: Originally I wrote this as an at::redispatch call, but
  863. # I got in trouble because that meant I needed a DispatchKeySet
  864. # in the wrapper function, which meant I needed a DispatchKeySet
  865. # in the DispatchKeyFunctions declarations, but the defined API
  866. # there does NOT permit a dispatch key set. I think you can
  867. # probably unwind this by calling some function to do the TLS
  868. # fetch and get the DispatchKeySet when you don't have it, but
  869. # I didn't do it for this version
  870. sig_body.append(f"at::{api_name}({out_exprs});")
  871. elif self.backend_index.dispatch_key != DispatchKey.Meta:
  872. impl_exprs = ", ".join(
  873. e.expr
  874. for e in translate(
  875. context, structured.impl_arguments(self.g), method=False
  876. )
  877. )
  878. sig_body.append(f"op.impl({impl_exprs});")
  879. # Go over each output, and check if there is a proxy created for it.
  880. # If so, copy it over to the original output.
  881. if k is SchemaKind.out or k is SchemaKind.inplace:
  882. for i in range(len(f.func.returns)):
  883. sig_body.append(
  884. f"if (op.proxy_outputs_[{i}].has_value()) op.outputs_[{i}].get().copy_(*op.proxy_outputs_[{i}]);"
  885. )
  886. # Destructively return the final tensors
  887. # TODO: Do this in translate instead
  888. if k is SchemaKind.functional:
  889. if len(f.func.returns) == 1:
  890. ret_expr = "std::move(op.outputs_[0])" # small optimization
  891. else:
  892. moved = ", ".join(
  893. f"std::move(op.outputs_[{i}])"
  894. for i in range(len(f.func.returns))
  895. )
  896. ret_expr = f"std::make_tuple({moved})"
  897. elif k is SchemaKind.inplace:
  898. ret_expr = "self"
  899. elif k is SchemaKind.out:
  900. if len(f.func.returns) == 1:
  901. ret_expr = f.func.arguments.out[0].name
  902. else:
  903. refs = ", ".join(a.name for a in f.func.arguments.out)
  904. ret_expr = f"std::forward_as_tuple({refs})"
  905. sig_body.append(f"return {ret_expr};") # type: ignore[possibly-undefined] # TODO: audit
  906. sig_body_str = "\n".join(sig_body)
  907. # For an overview of what this template code looks like, see
  908. # https://github.com/pytorch/rfcs/pull/9
  909. return f"""\
  910. {
  911. self.gen_class(
  912. f,
  913. k,
  914. class_name=class_name,
  915. parent_class=parent_class,
  916. generate_super=self.g.out.structured_inherits is not None,
  917. )
  918. }
  919. {sig.defn()} {{
  920. {sig_body_str}
  921. }}
  922. """
  923. elif self.target is Target.REGISTRATION:
  924. return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
  925. else:
  926. assert_never(self.target)
  927. # Silence mypy's "Missing return statement" error
  928. return None