native_function_generation.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662
  1. from __future__ import annotations
  2. import string
  3. from collections import defaultdict
  4. from typing import TYPE_CHECKING
  5. import torchgen.api.dispatcher as dispatcher
  6. from torchgen.api.translate import translate
  7. from torchgen.api.types import Binding, DispatcherSignature, Expr
  8. from torchgen.context import with_native_function
  9. from torchgen.model import (
  10. Annotation,
  11. Argument,
  12. BackendIndex,
  13. BackendMetadata,
  14. BaseOperatorName,
  15. BaseTy,
  16. BaseType,
  17. DEFAULT_KERNEL_NAMESPACE,
  18. DeviceCheckType,
  19. DispatchKey,
  20. FunctionSchema,
  21. NativeFunction,
  22. NativeFunctionsGroup,
  23. OperatorName,
  24. Return,
  25. SchemaKind,
  26. Variant,
  27. )
  28. from torchgen.utils import concatMap
  29. if TYPE_CHECKING:
  30. from collections.abc import Sequence
# See Note: [Out ops with functional variants that don't get grouped properly]
# Allow-list of out= operator names (compared against `str(func.name)`) that are
# permitted to be the only variant in their group. add_generated_native_functions
# raises an AssertionError for any other out= operator that has no sibling variant.
OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY: list[str] = [
    # This has a functional variant, but it's currently marked private.
    # This function should be marked private as well (*_backward ops aren't exposed to python anyway).
    "adaptive_avg_pool3d_backward.grad_input",
    # There's a functional variant, _slow_conv2d_backward.output_mask, that isn't grouped properly.
    # Maybe we can kill this operator in favor of convolution_backward?
    "_slow_conv2d_backward.grad_input",
]
# See Note: [Mutable ops that cannot get an out variant]
# Operator names (compared against `str(func.name)`) that are exempt from the
# "could not generate an out= variant" assertion in add_generated_native_functions.
MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT: list[str] = [
    # should be out=?
    "_cummax_helper",
    # should be out=?
    "_cummin_helper",
]
# None of these operators has a tensor-like return, so there is nothing that could
# be turned into an out= argument. Names listed here (compared against
# `str(func.name)`) are exempt from the "could not generate an out= variant"
# assertion in add_generated_native_functions.
FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT: list[str] = [
    "_assert_async",  # no return
    "_assert_async.msg",  # no return
    "_assert_tensor_metadata",  # no return
    "_cslt_sparse_mm_search",  # returns an int
    "_assert_scalar",  # no return
    "_dimI",  # returns an int
    "_dimV",  # returns an int
    "_has_same_storage_numel",  # returns a boolean
    "_linalg_check_errors",  # no return
    "_local_scalar_dense",  # returns a Scalar
    "_nested_tensor_from_mask_left_aligned",  # returns a boolean
    "_nnz",  # returns an int
    "_use_cudnn_ctc_loss",  # returns a boolean
    "_use_cudnn_ctc_loss.Tensor",  # returns a boolean
    "_use_miopen_ctc_loss",  # returns a boolean
    "_use_miopen_ctc_loss.Tensor",  # returns a boolean
    "_validate_compressed_sparse_indices",  # no return
    "allclose",  # returns a boolean
    "dense_dim",  # returns an int
    "equal",  # returns a boolean
    "is_coalesced",  # returns a boolean
    "is_pinned",  # returns a boolean
    "is_same_size",  # returns a boolean
    "is_set_to",  # returns a boolean
    "q_per_channel_axis",  # returns an int
    "q_scale",  # returns a float
    "q_zero_point",  # returns an int
    "qscheme",  # returns a QScheme
    "record_stream",  # no return
    "sparse_dim",  # returns an int
    "sym_constrain_range",  # no return
    "sym_constrain_range_for_size",  # no return
    "_nested_tensor_storage_offsets",  # returns a vector of ints
    "_chunk_grad_outputs_efficient_attention",  # returns a bool
    "_fused_sdp_choice",  # returns an int
    "_print",  # no return
    "_sink_tokens",  # no return
    "_nested_get_ragged_idx",  # returns an int
]
# Inplace operator names (compared against `str(func.name)`) whose groups
# add_generated_native_functions skips entirely: no out= or functional variant
# is generated for them.
INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY: list[str] = [
    # polygamma and polygamma.out both exist, but have a
    # pre-self arg (while polygamma_ does not)
    # We should either fix this schema so it can be grouped properly,
    # or allow the codegen to generate new functional/out= NativeFunctions for this op
    # (which would require changing its overload name to prevent overload ambiguity).
    "polygamma_"
]
  96. # Groups "similar" NativeFunctions together
  97. # example add.Tensor, add_.Tensor, add.out
  98. # "similar" NativeFunctions are all expected to have an identical `signature()`,
  99. # But have differing SchemaKinds.
  100. def pre_group_native_functions(
  101. native_functions: Sequence[NativeFunction],
  102. ) -> dict[FunctionSchema, dict[SchemaKind, NativeFunction]]:
  103. pre_grouped_native_functions: dict[
  104. FunctionSchema, dict[SchemaKind, NativeFunction]
  105. ] = defaultdict(dict)
  106. for f in native_functions:
  107. d = pre_grouped_native_functions[f.func.signature()]
  108. if f.func.kind() in d:
  109. raise AssertionError(f"Duplicate schema kind {f.func.kind()} for {f.func}")
  110. d[f.func.kind()] = f
  111. return pre_grouped_native_functions
  112. # Returns the out variant overload name given a base function overload name
  113. def get_expected_out_variant_overload_name(overload_name: str | None) -> str:
  114. return "out" if not overload_name else f"{overload_name}_out"
  115. # Helper function: given an inplace FunctionSchema, generate its corresponding out= variant
  116. # Example before:
  117. # _add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
  118. # Example after:
  119. # _add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out)
  120. def self_to_out_signature(func: FunctionSchema) -> FunctionSchema:
  121. # Generating an out= schema from an inplace schema.
  122. if func.kind() != SchemaKind.inplace:
  123. raise AssertionError(f"Expected inplace schema, got {func.kind()}")
  124. if func.arguments.self_arg is None:
  125. raise AssertionError("Expected self_arg to be non-None")
  126. # The new out= schema has:
  127. # - a new out argument with the same type as "func" (but with a mutable annotation)
  128. # - The returns (if any) now alias the out= argument instead of "func"
  129. # - an "out" overload name
  130. return FunctionSchema(
  131. name=func.name.remove_inplace().with_overload(
  132. get_expected_out_variant_overload_name(func.name.overload_name)
  133. ),
  134. arguments=func.arguments.remove_self_annotation().with_out_args(
  135. [
  136. Argument(
  137. name="out",
  138. type=func.arguments.self_arg.argument.type,
  139. default=None,
  140. annotation=func.arguments.self_arg.argument.annotation,
  141. )
  142. ]
  143. ),
  144. returns=func.returns,
  145. )
  146. # Helper function: given a functional FunctionSchema, generate its corresponding out= variant
  147. # Example before:
  148. # _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
  149. # bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
  150. # Example after:
  151. # _to_copy._out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None,
  152. # Tensor(a!) out) -> Tensor(a!)
  153. def functional_to_out_signature(func: FunctionSchema) -> FunctionSchema:
  154. # Generating an out= schema from a functional schema.
  155. if func.kind() != SchemaKind.functional:
  156. raise AssertionError(f"Expected functional schema, got {func.kind()}")
  157. new_returns, new_out_args = generate_out_args_from_schema(func)
  158. # The new out= schema has:
  159. # - one or more new out argument(s) with the same type as returns (but with a mutable annotation)
  160. # - The returns now alias the out= arguments
  161. # - an "_out" overload name
  162. return FunctionSchema(
  163. name=func.name.with_overload(
  164. get_expected_out_variant_overload_name(func.name.overload_name)
  165. ),
  166. arguments=func.arguments.signature().with_out_args(
  167. new_out_args,
  168. ),
  169. returns=tuple(new_returns),
  170. )
  171. # Helper function: given a function schema, generate corresponding out arguments, also the updated return annotations.
  172. def generate_out_args_from_schema(
  173. func: FunctionSchema,
  174. ) -> tuple[list[Return], list[Argument]]:
  175. # More of a sanity check - our existing restrictions on schemas should enforce that
  176. # mutable schema kinds never return their mutable arguments.
  177. if any(r.annotation is not None and r.annotation.is_write for r in func.returns):
  178. raise AssertionError("Mutable schema kinds should not return mutable arguments")
  179. tensorlike_rets = [r for r in func.returns if r.type.is_tensor_like()]
  180. if len(tensorlike_rets) == 0:
  181. raise AssertionError("Expected at least one tensor-like return")
  182. used_annotations = concatMap(
  183. lambda a: [] if a.annotation is None else a.annotation.alias_set,
  184. func.arguments.flat_all,
  185. )
  186. valid_annotations = [x for x in string.ascii_lowercase if x not in used_annotations]
  187. all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns)
  188. new_out_args: list[Argument] = []
  189. # The end result of new_returns is that:
  190. # - If every return is a plain tensor, then the new returns == the old returns, but with the out= alias annotations added.
  191. # - Otherwise, none of the out arguments show up in the returns (and we're only left with non-tensor-like returns, if any).
  192. new_returns: list[Return] = []
  193. for i, r in enumerate(func.returns):
  194. if r.type.is_tensor_like():
  195. new_out = Argument(
  196. name="out" if len(func.returns) == 1 else f"out{i}",
  197. type=r.type,
  198. default=None,
  199. annotation=Annotation.parse(f"{valid_annotations[i]}!"),
  200. )
  201. new_out_args.append(new_out)
  202. if all_rets_are_tensors:
  203. # The convention for out= schemas is that they only return their out arguments
  204. # if the return is a plain Tensor (or if it's a tuple of plain Tensors)
  205. new_ret = Return(
  206. name=None, type=new_out.type, annotation=new_out.annotation
  207. )
  208. new_returns.append(new_ret)
  209. else:
  210. new_returns.append(r)
  211. return new_returns, new_out_args
  212. # Helper function: given a mutable FunctionSchema, generate its corresponding out= variant
  213. # Example before:
  214. # _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950
  215. # Example after:
  216. # _fused_moving_avg_obs_fq_helper._out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) # noqa: B950
  217. def mutable_to_out_signature(func: FunctionSchema) -> FunctionSchema:
  218. # Generating an out= schema from a mutable schema.
  219. if func.kind() != SchemaKind.mutable:
  220. raise AssertionError(f"Expected mutable schema, got {func.kind()}")
  221. # The new out= schema has:
  222. # - Any non-aliased tensor-like returns are converted to mutable, aliased out= arguments
  223. # (if the argument is a tensor then we also return it for method chaining,
  224. # otherwise we return nothing)
  225. # - an "out" overload name
  226. #
  227. # Note that:
  228. # (1) This also means that we can *only* generate an out= variant from a mutable schema
  229. # if the mutable schema has at least one tensor-like non-aliasing return.
  230. # (2) The generated out= variant still has mutable positional arguments,
  231. # but if necessary we could probably add another out= variant that also
  232. # functionalizes the mutable arguments (a functional_out variant)
  233. new_returns, new_out_args = generate_out_args_from_schema(func)
  234. return FunctionSchema(
  235. name=func.name.remove_inplace().with_overload(
  236. get_expected_out_variant_overload_name(func.name.overload_name)
  237. ),
  238. arguments=func.arguments.with_out_args(new_out_args),
  239. returns=tuple(new_returns),
  240. )
  241. # This function, given function of one SchemaKind, as well as a target SchemaKind,
  242. # generates a new NativeFunction with the same properties, but using the target SchemaKind.
  243. # We only actually generate functions for either functional or out= SchemaKinds.
  244. # This function returns a tuple, with:
  245. # - The generated NativeFunction
  246. # - a dictionary of `BackendIndex` objects, describing which dispatch keys
  247. # we will generate kernels for, for the new NativeFunction.
  248. # Details are in the function, but we only generate composite kernels (in some cases) today.
  249. def generate_function(
  250. f: NativeFunction, k: SchemaKind
  251. ) -> tuple[NativeFunction, dict[DispatchKey, dict[OperatorName, BackendMetadata]]]:
  252. from torchgen.api import cpp
  253. if k == SchemaKind.functional:
  254. if f.func.kind() == SchemaKind.functional:
  255. raise AssertionError("Cannot generate functional from functional schema")
  256. # The new "functional" NativeFunction has:
  257. # - any mutable arguments have been converted into (immutable) returns.
  258. # (if a mutable argument was not also a return, it gets converted to one)
  259. # - "_functional" appended to the base name, ONLY IF this op has a mutable variant.
  260. # See Note [Overload Ambiguity With Functional Variants]
  261. # The default grouping logic in signature() actually already does this,
  262. # so we can piggy-back off it (but we still want return names)
  263. func = f.func.signature(keep_return_names=True).with_name(
  264. OperatorName(
  265. name=BaseOperatorName(
  266. base=f.func.name.name.base,
  267. inplace=False,
  268. dunder_method=f.func.name.name.dunder_method,
  269. # See Note [Overload Ambiguity With Functional Variants]
  270. functional_overload=f.func.kind() == SchemaKind.mutable,
  271. ),
  272. overload_name=f.func.name.overload_name,
  273. )
  274. )
  275. elif k == SchemaKind.out:
  276. # We generate out= ops mostly just so that we can pair up NativeFunctions into groups easily,
  277. # but at least today, there is no good reason to actually use them.
  278. # we'll generate a dispatcher entry for them, but won't actually register any kernels for them.
  279. if f.func.kind() == SchemaKind.inplace:
  280. func = self_to_out_signature(f.func)
  281. elif f.func.kind() == SchemaKind.mutable:
  282. func = mutable_to_out_signature(f.func)
  283. elif f.func.kind() == SchemaKind.functional:
  284. func = functional_to_out_signature(f.func)
  285. else:
  286. raise AssertionError(
  287. "We only bother generating out= functions from either inplace or mutable or functional variants"
  288. )
  289. else:
  290. raise AssertionError(
  291. "We currently only generate either functional or out= NativeFunctions"
  292. )
  293. # Generated kernel naming convention for out: <op_name>_<overload_name>. The reason for this is to
  294. # disambiguate operator with the same name but different overload name, e.g., `randn.names_out` and
  295. # `randn.generator_with_names_out`.
  296. kernel_name = (
  297. func.name.unambiguous_name()
  298. if func.kind() == SchemaKind.out
  299. else cpp.name(func)
  300. )
  301. if f.func.has_symint():
  302. kernel_name += "_symint"
  303. backend_metadata = {
  304. DispatchKey.CompositeExplicitAutograd: {
  305. func.name: BackendMetadata(
  306. kernel=kernel_name,
  307. structured=False,
  308. cpp_namespace=DEFAULT_KERNEL_NAMESPACE,
  309. )
  310. }
  311. }
  312. tags = {"generated"} | set(
  313. f.tags & {"nondeterministic_seeded", "view_copy", "pt2_compliant_tag"}
  314. )
  315. return (
  316. NativeFunction(
  317. func=func,
  318. use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors,
  319. # These generated fn's aren't meant to be user friendly- don't generate methods.
  320. variants={Variant.function},
  321. structured=False,
  322. structured_delegate=None,
  323. structured_inherits=None,
  324. precomputed=None,
  325. autogen=[],
  326. ufunc_inner_loop={},
  327. manual_kernel_registration=False,
  328. manual_cpp_binding=False,
  329. python_module=None,
  330. category_override=None,
  331. device_guard=False,
  332. device_check=DeviceCheckType.NoCheck,
  333. loc=f.loc,
  334. cpp_no_default_args=set(),
  335. is_abstract=f.is_abstract,
  336. has_composite_implicit_autograd_kernel=False,
  337. has_composite_implicit_autograd_nested_tensor_kernel=False,
  338. has_composite_explicit_autograd_kernel=True,
  339. has_composite_explicit_autograd_non_functional_kernel=False,
  340. # Every generated NativeFunction gets a "generated" tag, so it's easy to tell
  341. # which NativeFunction objects did not come directly from native_functions.yaml.
  342. tags=tags,
  343. namespace=f.namespace,
  344. ),
  345. backend_metadata,
  346. )
  347. # This function is responsible for adding generated NativeFunctions which don't appear
  348. # explicitly in the codegen.
  349. # You can inspect the full list of NativeFunctions yourself with the torchgen package, by running
  350. # torchgen.parse_native_yaml("aten/src/ATen/native/native_functions.yaml", "aten/src/ATen/native/tags.yaml")
  351. # (Maybe we should make a friendly API for this)
  352. #
  353. # Note: this function *mutates* its two inputs,
  354. # adding the new NativeFunctions / BackendMetadata to them
  355. def add_generated_native_functions(
  356. rs: list[NativeFunction],
  357. indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]],
  358. ) -> None:
  359. # The main code for generating new NativeFunctions
  360. # First we group of NativeFunctions by schema kind,
  361. # then we detect which ones are missing and generate them.
  362. pre_grouped_native_functions = pre_group_native_functions(rs)
  363. for d in pre_grouped_native_functions.values():
  364. has_functional = SchemaKind.functional in d
  365. has_inplace = SchemaKind.inplace in d
  366. has_mutable = SchemaKind.mutable in d
  367. has_out = SchemaKind.out in d
  368. is_core = any("core" in variant.tags for variant in d.values())
  369. # We automatically generate a few native functions that don't exist in the yaml, for a few reasons:
  370. # (1) If an operator has an inplace/out= variant but no functional variant, we can generate
  371. # a simple functional variant that the functionalization pass can consume.
  372. # (2) If an operator has an inplace or functional but no out= variant, we generate an out=
  373. # variant, mostly so we can easily pair up functions into NativeFunctionsGroup,
  374. # while maintaining the constraint that the out= variant is "required".
  375. if has_mutable or has_inplace or has_out or has_functional:
  376. # Don't bother generating functions trio's for native functions that bypass the dispatcher.
  377. are_manual = all(f.manual_cpp_binding for f in d.values())
  378. # Don't bother generating functional + out= variants for view operators
  379. # set_ is technically an inplace_view, but for now it is treated
  380. # as a normal inplace op in the codegen
  381. has_view_ops = any(
  382. f.is_view_op and str(f.func.name.name) != "set_" for f in d.values()
  383. )
  384. # Don't generate the other variants for non-core CompositeImplicitAutograd operators.
  385. # We could probably do this, but the main benefit of generating the function triplets
  386. # is for transforms that need them, and transforms don't need to act directly
  387. # on CompositeImplicitAutograd operators (since we let them decompose).
  388. are_composite_implicit = all(
  389. f.has_composite_implicit_autograd_kernel for f in d.values()
  390. )
  391. if are_manual or has_view_ops or are_composite_implicit and not is_core:
  392. continue
  393. if has_out and len(d.values()) == 1:
  394. # Note: [Out ops with functional variants that don't get grouped properly]
  395. # In theory we could validly have an out= operator in native_functions.yaml
  396. # that has no other variants.
  397. # But today, all of the operators where that's the case actually do have
  398. # functional variants, that we are just unable to pair up properly.
  399. # I think banning this all together is probably safer
  400. # (you can always add a functional variant yourself if you want to add a new out= operator).
  401. #
  402. # We should probably fix the existing cases; this check is to prevent us from adding more over time.
  403. if (
  404. str(d[SchemaKind.out].func.name)
  405. not in OUT_OPS_THAT_DONT_GET_GROUPED_PROPERLY
  406. ):
  407. raise AssertionError(
  408. f"Found an out= operator that we could not find any other variants of: {str(d[SchemaKind.out].func)}"
  409. )
  410. continue
  411. # Some inplace ops that have problematic schemas (that we should fix), which prevent us
  412. # from generating out= and functional variants
  413. if (
  414. has_inplace
  415. and str(d[SchemaKind.inplace].func.name)
  416. in INPLACE_OPS_THAT_DONT_GET_GROUPED_PROPERLY
  417. ):
  418. continue
  419. base_fn = (
  420. d[SchemaKind.mutable]
  421. if has_mutable
  422. else d[SchemaKind.inplace]
  423. if has_inplace
  424. else d[SchemaKind.out]
  425. if has_out
  426. else d[SchemaKind.functional]
  427. )
  428. # Note: [Mutable ops that cannot get an out variant]
  429. # We can only generate an out= variant if either:
  430. # - the original function has tensor-like returns (since we can convert them to out kwargs)
  431. # - or it's inplace (since we can convert `self` to an out kwarg)
  432. # There are only two functions that don't fit this criteria today though,
  433. # and they both look like they should be fixed to be out= variants,
  434. # so if feels safer to ban this schema all-together
  435. base_fn_valid = base_fn.func.kind() == SchemaKind.inplace or any(
  436. r.type.is_tensor_like() for r in base_fn.func.returns
  437. )
  438. # Note: [Loosen the assertion that all functional should have out variant]
  439. # By design all functional operators should have our variants. The needs_out check
  440. # is loosening this requirement, changing it to only generate out variant if there's
  441. # an `autogen` block in the native function, in the long run it should be removed.
  442. # FIXME: Remove this after figuring out CI job failures related to min, max, mean
  443. needs_out = any("out" in str(op_name) for op_name in base_fn.autogen)
  444. gets_out_variant = not has_out and base_fn_valid and needs_out
  445. if not has_out and not base_fn_valid:
  446. if (
  447. str(base_fn.func.name)
  448. not in MUTABLE_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
  449. and str(base_fn.func.name)
  450. not in FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT
  451. ):
  452. raise AssertionError(
  453. f"""Found an operator that we could not generate an out= variant for: {str(base_fn.func)}.
  454. This type of operators don't have tensor-like return, making it difficult to generate a proper out= variant. If
  455. out= variant is not needed, please add the function name into FUNCTIONAL_OPS_THAT_CANNOT_GET_AN_OUT_VARIANT list."""
  456. )
  457. # Generate an out= variant
  458. if gets_out_variant:
  459. fn, metadata = generate_function(base_fn, SchemaKind.out)
  460. d[SchemaKind.out] = fn
  461. BackendIndex.grow_index(indices, metadata)
  462. rs.append(fn)
  463. # Generate a functional variant, but only do it if the operator got an out= variant
  464. # (Functional variants are only useful if we can group up the variants,
  465. # which we can only do if they have an out= variant)
  466. if not has_functional and (has_out or gets_out_variant):
  467. fn, metadata = generate_function(base_fn, SchemaKind.functional)
  468. d[SchemaKind.functional] = fn
  469. BackendIndex.grow_index(indices, metadata)
  470. rs.append(fn)
  471. def return_str(rets: tuple[Return, ...], names: list[str]) -> str:
  472. if len(rets) != len(names):
  473. raise AssertionError(
  474. f"Returns and names length mismatch: {len(rets)} vs {len(names)}"
  475. )
  476. if len(rets) == 0:
  477. return ""
  478. elif len(rets) == 1:
  479. return f"return {names[0]};"
  480. else:
  481. return f"return {dispatcher.returns_type(rets).cpp_type()}({', '.join(names)});"
  482. # Given a function, and the name of a variable corresponding to the output of that function,
  483. # gather up all of the individual returns that are not aliased
  484. def gather_nonaliased_inner_rets(func: FunctionSchema, out_var: str) -> list[str]:
  485. aliased_rets = func.aliased_return_names()
  486. non_aliased_names = []
  487. is_out_var_a_tuple = len(func.returns) > 1
  488. for i, r in enumerate(aliased_rets):
  489. if r is None:
  490. non_aliased_names.append(
  491. f"std::get<{i}>({out_var})" if is_out_var_a_tuple else out_var
  492. )
  493. return non_aliased_names
  494. # Generates functional kernels in terms of their inplace.mutable counterparts.
  495. # We only do this for "generated" NativeFunctions
  496. @with_native_function
  497. def gen_composite_functional_kernel(g: NativeFunctionsGroup) -> str | None:
  498. # We should only be generating these for code-generated NativeFunctions
  499. if "generated" not in g.functional.tags:
  500. return None
  501. # And we always write the kernel for a generated op in terms of a non-generated op.
  502. if g.inplace is not None and "generated" not in g.inplace.tags:
  503. target_f = g.inplace
  504. elif g.mutable is not None and "generated" not in g.mutable.tags:
  505. target_f = g.mutable
  506. else:
  507. # We should be guaranteed to have a valid inplace/mutable variant to call into.
  508. # See Note: [Mutable Ops Not Using Functionalization]
  509. raise AssertionError(str(g.functional.func))
  510. sig = DispatcherSignature(g.functional.func)
  511. target_sig = DispatcherSignature(target_f.func)
  512. context: list[Binding | Expr] = []
  513. clone_mutable_inputs = []
  514. cloned_return_names = []
  515. # We can't just directly pass all of the arguments from the functional op into the mutating op.
  516. # We need to check for which inputs to the mutating operator are mutable,
  517. # and clone those inputs first.
  518. for a_curr, a_tgt in zip(
  519. dispatcher.jit_arguments(g.functional.func),
  520. dispatcher.jit_arguments(target_f.func),
  521. ):
  522. if a_tgt.annotation is not None and a_tgt.annotation.is_write:
  523. clone_mutable_inputs.append(
  524. f"auto {a_curr.name}_clone = clone_arg({a_curr.name});"
  525. )
  526. context.append(
  527. Expr(
  528. expr=f"{a_curr.name}_clone",
  529. type=dispatcher.argument_type(a_curr, binds=a_curr.name),
  530. )
  531. )
  532. # Invariant: mutable arguments on the inner mutable op are always returns on the functional op.
  533. cloned_return_names.append(f"{a_curr.name}_clone")
  534. else:
  535. context.append(dispatcher.argument(a_curr))
  536. exprs = ", ".join([e.expr for e in translate(context, target_sig.arguments())])
  537. out_name = "output"
  538. maybe_assign = f"auto {out_name} = " if len(target_f.func.returns) > 0 else ""
  539. inner_return_names = gather_nonaliased_inner_rets(target_f.func, out_name)
  540. ret_str = return_str(
  541. g.functional.func.returns, inner_return_names + cloned_return_names
  542. )
  543. clone_mutable_inputs_str = "\n".join(clone_mutable_inputs)
  544. return f"""
  545. {sig.defn(name=sig.name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  546. {clone_mutable_inputs_str}
  547. {maybe_assign}at::_ops::{target_f.func.name.unambiguous_name()}::call({exprs});
  548. {ret_str}
  549. }}
  550. """
  551. # Generates out= kernels in terms of their functional counterparts.
  552. # We only do this for "generated" NativeFunctions
  553. @with_native_function
  554. def gen_composite_out_kernel(g: NativeFunctionsGroup) -> str | None:
  555. # We should only be generating these for code-generated NativeFunctions
  556. if "generated" not in g.out.tags:
  557. return None
  558. # And we always write the kernel for the out= op in terms of the functional.
  559. # Note that the functional op might have also been generated, but we don't have to
  560. # worry about cycles, because the generated functional kernels are always implemented
  561. # in terms of non-generated kernels (see gen_composite_functional_kernel).
  562. sig = DispatcherSignature(g.out.func)
  563. target_sig = DispatcherSignature(g.functional.func)
  564. exprs = ", ".join(
  565. [e.expr for e in translate(sig.arguments(), target_sig.arguments())]
  566. )
  567. copy_outs = []
  568. out_name = "tmp_output"
  569. for i, out_arg in enumerate(g.out.func.arguments.out):
  570. functional_return_name = (
  571. out_name
  572. if len(g.functional.func.returns) == 1
  573. else f"std::get<{i}>({out_name})"
  574. )
  575. copy_outs.append(
  576. f"""\
  577. resize_out_helper({out_arg.name}, {functional_return_name});
  578. copy_arg({out_arg.name}, {functional_return_name});"""
  579. )
  580. rets = []
  581. # For each return arg in the calling (out=) operator,
  582. # If it corresponds to an aliased input, return the input.
  583. # Otherwise, return the corresponding output from calling the functional operator.
  584. for i, ret_name in enumerate(g.out.func.aliased_return_names()):
  585. if ret_name is not None:
  586. rets.append(ret_name)
  587. else:
  588. functional_return_name = (
  589. out_name
  590. if len(g.functional.func.returns) == 1
  591. else f"std::get<{i}>({out_name})"
  592. )
  593. rets.append(functional_return_name)
  594. copy_outs_str = "\n".join(copy_outs)
  595. # Kernel name needs to follow the naming convention defined in `generate_function()`
  596. return f"""
  597. {sig.defn(name=g.out.func.name.unambiguous_name() + ("_symint" if g.out.func.has_symint() else ""))} {{
  598. auto {out_name} = at::_ops::{g.functional.func.name.unambiguous_name()}::call({exprs});
  599. {copy_outs_str}
  600. {return_str(g.out.func.returns, rets)}
  601. }}
  602. """