gen_aoti_c_shim.py 28 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774
  1. from __future__ import annotations
  2. import difflib
  3. import os
  4. import textwrap
  5. from dataclasses import dataclass
  6. from typing import TYPE_CHECKING
  7. from torchgen.aoti.fallback_ops import aten_shimified_ops, inductor_fallback_ops
  8. from torchgen.api.types import DispatcherSignature
  9. from torchgen.api.types.signatures import CppSignature, CppSignatureGroup
  10. from torchgen.context import method_with_native_function
  11. from torchgen.model import (
  12. Argument,
  13. BackendIndex,
  14. BaseTy,
  15. BaseType,
  16. DispatchKey,
  17. FunctionSchema,
  18. is_cuda_dispatch_key,
  19. ListType,
  20. NativeFunction,
  21. NativeFunctionsGroup,
  22. OperatorName,
  23. OptionalType,
  24. Type,
  25. Variant,
  26. )
  27. from torchgen.utils import FileManager, mapMaybe
  28. if TYPE_CHECKING:
  29. from collections.abc import Sequence
  30. base_type_to_c_type = {
  31. BaseTy.Tensor: "AtenTensorHandle",
  32. BaseTy.bool: "int32_t", # Use int to pass bool
  33. BaseTy.int: "int64_t",
  34. BaseTy.SymInt: "int64_t", # Inductor-generated code won't see a SymInt
  35. BaseTy.Scalar: "double", # Use double to pass both integer and floating point
  36. BaseTy.float: "double", # TODO: how about other floating point types?
  37. BaseTy.str: "const char*",
  38. BaseTy.DeviceIndex: "int32_t",
  39. BaseTy.Layout: "int32_t", # Represent enum as int
  40. BaseTy.MemoryFormat: "int32_t", # Represent enum as int
  41. BaseTy.ScalarType: "int32_t", # Represent enum as int
  42. BaseTy.Generator: "AtenGeneratorHandle",
  43. }
  44. base_type_to_aten_type = {
  45. BaseTy.Tensor: "at::Tensor",
  46. BaseTy.bool: "bool",
  47. BaseTy.int: "int64_t",
  48. BaseTy.SymInt: "c10::SymInt",
  49. BaseTy.Scalar: "c10::Scalar",
  50. BaseTy.float: "double",
  51. BaseTy.str: "::std::string_view",
  52. BaseTy.DeviceIndex: "c10::DeviceIndex",
  53. BaseTy.Layout: "c10::Layout",
  54. BaseTy.MemoryFormat: "c10::MemoryFormat",
  55. BaseTy.ScalarType: "c10::ScalarType",
  56. BaseTy.Generator: "at::Generator",
  57. }
  58. base_type_to_callsite_expr = {
  59. BaseTy.Tensor: "resolve_tensor_dispatch_flags",
  60. BaseTy.bool: "",
  61. BaseTy.int: "",
  62. BaseTy.SymInt: "",
  63. BaseTy.Scalar: "",
  64. BaseTy.float: "",
  65. BaseTy.str: "",
  66. BaseTy.DeviceIndex: "static_cast<c10::DeviceIndex>",
  67. BaseTy.Layout: "static_cast<c10::Layout>",
  68. BaseTy.MemoryFormat: "static_cast<c10::MemoryFormat>",
  69. BaseTy.ScalarType: "static_cast<c10::ScalarType>",
  70. BaseTy.Generator: "*generator_handle_to_generator_pointer",
  71. }
  72. # convert args to C types, names in declarations, and expressions in function bodies
  73. def convert_arg_type_and_name(
  74. typ: Type,
  75. name: str,
  76. is_write: bool = False,
  77. ) -> tuple[list[str], list[str], list[str], list[str]]:
  78. if isinstance(typ, BaseType):
  79. if typ.name in base_type_to_c_type:
  80. if typ.name == BaseTy.Tensor and is_write:
  81. # For output tensors, our normal call to resolve_tensor_dispatch_flags
  82. # results in an rvalue tensor, which can't be passed to at::Tensor&.
  83. # Override this case specifically.
  84. callsite_expr = [f"*tensor_handle_to_tensor_pointer({name})"]
  85. else:
  86. callsite_expr = [
  87. f"{base_type_to_callsite_expr[typ.name]}({name})"
  88. if base_type_to_callsite_expr[typ.name]
  89. else name
  90. ]
  91. return (
  92. [base_type_to_c_type[typ.name]],
  93. [name],
  94. [base_type_to_aten_type[typ.name]],
  95. callsite_expr,
  96. )
  97. elif typ.name == BaseTy.Device:
  98. return (
  99. ["int32_t", "int32_t"],
  100. [name, name + "_index_"],
  101. ["c10::Device"],
  102. [
  103. f"c10::Device(static_cast<c10::DeviceType>({name}), static_cast<c10::DeviceIndex>({name}_index_))"
  104. ],
  105. )
  106. else:
  107. # TODO: BaseTy.Dimname, etc.
  108. raise NotImplementedError(f"TODO: add support for arg type {repr(typ)}")
  109. elif isinstance(typ, OptionalType):
  110. c_types, names, aten_types, callsite_exprs = convert_arg_type_and_name(
  111. typ.elem, name
  112. )
  113. j = 0 # index for names
  114. new_aten_types = []
  115. new_callsite_exprs = []
  116. for aten_type in aten_types:
  117. # Use pointer to denote optional type
  118. c_types[j] = c_types[j] + "*"
  119. if aten_type.startswith("c10::ArrayRef<"):
  120. # ArrayRef is passed as pointer + size, but no need to add "*" to the size argument
  121. new_aten_types.append(f"::std::optional<{aten_type}>")
  122. base_type = aten_type[len("c10::ArrayRef<") : -1]
  123. new_callsite_exprs.append(
  124. f"pointer_to_optional_list<{base_type}>({names[j]}, {names[j + 1]})"
  125. )
  126. j += 2
  127. elif aten_type == "c10::Device":
  128. # Device is passed as device_type + device_index
  129. new_aten_types.append("::std::optional<c10::Device>")
  130. new_callsite_exprs.append(
  131. f"pointer_to_optional_device({names[j]}, {names[j + 1]})"
  132. )
  133. j += 2
  134. elif aten_type == "at::Tensor":
  135. new_aten_types.append(f"::std::optional<{aten_type}>")
  136. new_callsite_exprs.append(f"resolve_tensor_dispatch_flags({names[j]})")
  137. j += 1
  138. else:
  139. new_aten_types.append(f"::std::optional<{aten_type}>")
  140. new_callsite_exprs.append(
  141. f"pointer_to_optional<{aten_type}>({names[j]})"
  142. )
  143. j += 1
  144. return (
  145. c_types,
  146. names,
  147. new_aten_types,
  148. new_callsite_exprs,
  149. )
  150. elif isinstance(typ, ListType):
  151. # Need to explicitly pass the list as pointer + length
  152. c_types, names, aten_types, _ = convert_arg_type_and_name(typ.elem, name)
  153. if len(c_types) != 1:
  154. raise AssertionError(f"ListType with unsupported element type {repr(typ)}")
  155. # The list content should never be modified
  156. c_types[0] = f"const {c_types[0]}*"
  157. c_types.append("int64_t")
  158. name = names[0]
  159. names.append(name + "_len_")
  160. atype = aten_types[0]
  161. callsite_exprs = []
  162. if atype == "bool":
  163. # no converter from std::vector<bool> to c10::ArrayRef<bool>
  164. # construct std::array<bool, N> instead
  165. if typ.size is None:
  166. raise AssertionError("bool ListType must have a size")
  167. callsite_exprs.append(f"pointer_to_list<{typ.size}>({name})")
  168. elif atype == "at::Tensor" and not is_write:
  169. callsite_exprs.append(
  170. f"resolve_tensor_list_dispatch_flags({name}, {name}_len_)"
  171. )
  172. elif atype == "::std::optional<at::Tensor>":
  173. # convert from std::vector<::std::optional<at::Tensor>> to c10::List<::std::optional<at::Tensor>>
  174. callsite_exprs.append(
  175. f"c10::List<{atype}>(c10::ArrayRef<{atype}>(resolve_tensor_list_dispatch_flags({name}, {name}_len_)))"
  176. )
  177. else:
  178. callsite_exprs.append(f"pointer_to_list<{atype}>({name}, {name}_len_)")
  179. aten_types = [f"c10::ArrayRef<{t}>" for t in aten_types]
  180. return (
  181. c_types,
  182. names,
  183. aten_types,
  184. callsite_exprs,
  185. )
  186. raise NotImplementedError(f"Argument type {repr(typ)} not supported!")
  187. def zip_type_and_name(types: list[str], names: list[str]) -> list[str]:
  188. return [typ + " " + name for typ, name in zip(types, names)]
  189. # Generate argument declarations and callsite expressions
  190. def gen_arguments(
  191. flat_arguments: Sequence[Argument], skipped_args: set[str]
  192. ) -> tuple[list[str], list[str]]:
  193. types: list[str] = []
  194. new_names: list[str] = []
  195. callsite_exprs: list[str] = []
  196. for arg in flat_arguments:
  197. if arg.name in skipped_args:
  198. callsite_exprs.append("std::nullopt")
  199. continue
  200. new_types, names, _, new_callsite_exprs = convert_arg_type_and_name(
  201. arg.type, arg.name, arg.is_write
  202. )
  203. types.extend(new_types)
  204. new_names.extend(names)
  205. callsite_exprs.extend(new_callsite_exprs)
  206. return zip_type_and_name(types, new_names), callsite_exprs
  207. # Return values are passed out as pointer arguments because all the C shim functions
  208. # are expected to return AOTITorchError.
  209. # Generate returns as declarations and callsite expressions
  210. def gen_returns(schema: FunctionSchema) -> tuple[list[str], list[str]]:
  211. types = []
  212. names = []
  213. for idx, ret in enumerate(schema.returns):
  214. names.append(f"ret{idx}")
  215. if isinstance(ret.type, BaseType) and ret.type.name in base_type_to_c_type:
  216. types.append(base_type_to_c_type[ret.type.name] + "*")
  217. else:
  218. raise NotImplementedError(
  219. f"TODO: add support for return type {repr(ret.type)}"
  220. )
  221. def convert_return(typ: BaseType, val: str) -> str:
  222. if typ.name == BaseTy.Tensor:
  223. return f"new_tensor_handle(std::move({val}))"
  224. elif typ.name == BaseTy.SymInt:
  225. return f"{val}.expect_int()"
  226. elif typ.name == BaseTy.Scalar:
  227. return f"{val}.toDouble()"
  228. else:
  229. return val
  230. ret_pointer_can_be_null = False
  231. unambiguous_name = schema.name.unambiguous_name()
  232. for name in (
  233. "_functional_sym_constrain_range",
  234. "_scaled_dot_product_cudnn_attention",
  235. "_scaled_dot_product_efficient_attention_backward",
  236. "_scaled_dot_product_efficient_attention",
  237. "_scaled_dot_product_flash_attention",
  238. "_scaled_dot_product_fused_attention_overrideable",
  239. "_thhn_fused_lstm_cell_backward_impl",
  240. "convolution_backward",
  241. "grid_sampler_2d_backward",
  242. "grid_sampler_3d_backward",
  243. "linear_backward",
  244. ):
  245. if name in unambiguous_name:
  246. ret_pointer_can_be_null = True
  247. break
  248. callsite_exprs: list[str] = []
  249. for idx, ret in enumerate(schema.returns):
  250. tmp = "tmp_result" if len(names) == 1 else f"std::get<{idx}>(tmp_result)"
  251. if not isinstance(ret.type, BaseType):
  252. raise AssertionError(f"Expected BaseType for return, got {type(ret.type)}")
  253. rval = convert_return(ret.type, tmp)
  254. if ret_pointer_can_be_null:
  255. callsite_exprs.append(f"if ({names[idx]}) {{ *{names[idx]} = {rval}; }}")
  256. else:
  257. callsite_exprs.append(f"*{names[idx]} = {rval};")
  258. return zip_type_and_name(types, names), callsite_exprs
# gen.py generates header first and then src, so caching the result here to avoid duplicate work
# Maps (unambiguous op name, device, backend call) -> (declarations, definitions).
declaration_definition_cache: dict[tuple[str, str, str], tuple[str, str]] = {}


def gen_declaration_and_definition(
    schema: FunctionSchema,
    device: str,
    backend_call: str,
    version_info: dict[str, list[str]],
) -> tuple[str, str]:
    """Generate C shim declaration(s) and definition(s) for one operator.

    One C function is emitted per version: version 1 is always present and is
    named ``aoti_torch_{device}_{base_name}``; each extra version ``vN`` from
    ``version_info`` is named ``..._v{N}``. ``version_info`` maps a version
    string such as ``"v2"`` to the argument names introduced in that version.
    Older versions omit those arguments from their signature.

    Returns ``(declarations, definitions)``, each a newline-joined string with
    the newest version first. Results are memoized in
    ``declaration_definition_cache``.

    NOTE: ``version_info`` is not part of the cache key. A later call with the
    same (base_name, device, backend_call) returns the cached result
    regardless of the ``version_info`` passed.

    Raises AssertionError for malformed or duplicate version strings.
    Propagates NotImplementedError from unsupported argument/return types.
    """
    base_name = schema.name.unambiguous_name()

    # `global` is not strictly required here: the cache dict is only mutated, never rebound.
    global declaration_definition_cache
    if (base_name, device, backend_call) in declaration_definition_cache:
        return declaration_definition_cache[(base_name, device, backend_call)]

    # Check the validity of version_info. The format should look like
    # {"v2" : ["new_arg1"], "v3": ["new_arg2, new_arg3"]}.
    # Version 1 is implicit and introduces no new args.
    indexed_version_info: dict[int, list[str]] = {1: []}
    for ver_str, new_args in sorted(version_info.items()):
        if not ver_str.startswith("v"):
            raise AssertionError(
                f"Version number for {base_name} is {ver_str}, not starting with 'v'"
            )
        try:
            ver_id = int(ver_str[1:])
        except ValueError as e:
            raise AssertionError(
                f"Version number for {base_name} is {ver_str}, not a valid integer after 'v'"
            ) from e
        if ver_id in indexed_version_info:
            raise AssertionError(f"{ver_str} for {base_name} has already been defined")
        indexed_version_info[ver_id] = new_args

    declarations: list[str] = []
    definitions: list[str] = []
    # Args that are hidden from the version being generated; gen_arguments
    # passes std::nullopt for them at the call site.
    skipped_args: set[str] = set()

    for ver_id, new_args in sorted(indexed_version_info.items(), reverse=True):
        # Iterate in the reverse order, so the latest version of an op will get generated first
        # with all the arguments included, while a set of to-be-trimmed args is carried down
        # to generate earlier version of the op.
        func_name = base_name if ver_id == 1 else f"{base_name}_v{ver_id}"
        if schema.is_out_fn():
            # out_variant has out arguments in the front, and it's ok to ignore return values
            # because C shim functions only return AOTITorchError
            args, callsite_exprs = gen_arguments(
                [*schema.arguments.out, *schema.arguments.flat_non_out], skipped_args
            )
            ret_assignments: list[str] = []
        else:
            args, callsite_exprs = gen_arguments(
                schema.arguments.flat_all, skipped_args
            )
            # ignore return values for inplace ops
            ret_declarations, ret_assignments = (
                ([], []) if schema.name.name.inplace else gen_returns(schema)
            )
            # Return out-pointers go after the regular arguments.
            args.extend(ret_declarations)

        # dedent appears to be a no-op on this single-line string.
        declaration = textwrap.dedent(
            f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"
        )

        # Only capture the ATen result when there are out-pointers to fill.
        tmp_result = "auto tmp_result = " if ret_assignments else ""
        indent = "\t\t"
        ret_assignments_str = (
            "\n".join(indent + r for r in ret_assignments) if ret_assignments else ""
        )
        # The template text below is emitted verbatim into the generated .cpp;
        # keep it unindented so dedent leaves it unchanged.
        definition = (
            textwrap.dedent(f"""
{declaration} {{
AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({{
{tmp_result}{backend_call}(
{", ".join(callsite_exprs)}
);
""")
            + ret_assignments_str
            + textwrap.dedent("""
});
}
""")
        )
        # Args introduced in this version are hidden from all older versions.
        skipped_args.update(new_args)
        declarations.append(f"AOTI_TORCH_EXPORT {declaration};")
        definitions.append(definition)

    declaration_definition_cache[(base_name, device, backend_call)] = (
        "\n".join(declarations),
        "\n".join(definitions),
    )
    return declaration_definition_cache[(base_name, device, backend_call)]
  342. def gen_static_dispatch_backend_call_signature(
  343. sig: CppSignature | DispatcherSignature,
  344. f: NativeFunction,
  345. ) -> CppSignature:
  346. sig = DispatcherSignature.from_schema(f.func)
  347. cpp_sigs = CppSignatureGroup.from_native_function(
  348. f, method=False, fallback_binding=False
  349. )
  350. if sig.symint and f.func.has_symint():
  351. cpp_sig = cpp_sigs.symint_signature
  352. else:
  353. cpp_sig = cpp_sigs.signature
  354. if cpp_sig is None:
  355. raise AssertionError(f"No cpp signature found for {f.func.name}")
  356. return cpp_sig
  357. def gen_static_dispatch_backend_call(
  358. f: NativeFunction,
  359. backend_index: BackendIndex | None = None,
  360. ) -> str:
  361. sig = DispatcherSignature.from_schema(f.func)
  362. cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
  363. if backend_index is None:
  364. # Check if this is a symint function and if the function only has method variants
  365. if sig.symint and f.func.has_symint():
  366. has_function_variant = Variant.function in f.variants
  367. if not has_function_variant:
  368. # Functions with both function and method variants can use the at::{*}_symint version
  369. # (e.g., narrow -> at::narrow_symint), BUT
  370. # Method-only functions with symint parameters should use at::symint:: namespace
  371. # Remove the _symint suffix since at::symint:: namespace uses the base name
  372. # (e.g., new_empty -> at::symint::new_empty<c10::SymInt>)
  373. base_name = cpp_sig.name()
  374. base_name = base_name.removesuffix("_symint") # Remove "_symint" suffix
  375. return f"at::symint::{base_name}<c10::SymInt>"
  376. return f"at::{cpp_sig.name()}"
  377. else:
  378. return f"at::{backend_index.dispatch_key.lower()}::{cpp_sig.name()}"
  379. def get_backend_index_for_aoti(
  380. func: NativeFunction,
  381. func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
  382. dispatch_key: DispatchKey | None,
  383. backend_indices: dict[DispatchKey, BackendIndex],
  384. extend_aoti_c_shim: bool,
  385. ) -> BackendIndex | None:
  386. backend_index = None
  387. if dispatch_key is None:
  388. return backend_index
  389. if backend_indices[dispatch_key].has_kernel(func) or (
  390. func.structured_delegate is not None
  391. and func.structured_delegate in func_group_mapping
  392. and backend_indices[dispatch_key].has_kernel(
  393. func_group_mapping[func.structured_delegate]
  394. )
  395. ):
  396. backend_index = backend_indices[dispatch_key]
  397. else:
  398. # for the extend out-of-tree kernels, we don't need to
  399. # duplicatly create C shim wrappers for other dispatch keys
  400. if extend_aoti_c_shim:
  401. return backend_index
  402. elif backend_indices[DispatchKey.CompositeExplicitAutograd].has_kernel(func):
  403. # We need to create C shim wrappers for CompositeExplicitAutograd kernels
  404. backend_index = backend_indices[DispatchKey.CompositeExplicitAutograd]
  405. elif backend_indices[
  406. DispatchKey.CompositeExplicitAutogradNonFunctional
  407. ].has_kernel(func):
  408. # We need to create C shim wrappers for CompositeExplicitAutogradNonFunctional kernels
  409. backend_index = backend_indices[
  410. DispatchKey.CompositeExplicitAutogradNonFunctional
  411. ]
  412. elif backend_indices[DispatchKey.CompositeImplicitAutograd].has_kernel(func):
  413. backend_index = backend_indices[DispatchKey.CompositeImplicitAutograd]
  414. return backend_index
  415. def get_header_for_aoti(
  416. func: NativeFunction,
  417. func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
  418. dispatch_key: DispatchKey | None,
  419. backend_indices: dict[DispatchKey, BackendIndex],
  420. extend_aoti_c_shim: bool,
  421. ) -> str | None:
  422. backend_index = get_backend_index_for_aoti(
  423. func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
  424. )
  425. if backend_index is None:
  426. if dispatch_key is None:
  427. return f"#include <ATen/ops/{func.root_name}.h>"
  428. return None
  429. return f"#include <ATen/ops/{func.root_name}_{backend_index.dispatch_key.lower()}_dispatch.h>"
  430. def get_fallback_op_name(func: NativeFunction) -> str:
  431. return (
  432. f"{func.namespace}.{func.func.name.name}.{func.func.name.overload_name}"
  433. if func.func.name.overload_name
  434. else f"{func.namespace}.{func.func.name.name}.default"
  435. )
  436. def gen_c_shim(
  437. func: NativeFunction,
  438. version_info: dict[str, list[str]],
  439. func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
  440. dispatch_key: DispatchKey | None,
  441. backend_indices: dict[DispatchKey, BackendIndex],
  442. header: bool,
  443. extend_aoti_c_shim: bool,
  444. ) -> str | None:
  445. backend_index = get_backend_index_for_aoti(
  446. func, func_group_mapping, dispatch_key, backend_indices, extend_aoti_c_shim
  447. )
  448. if backend_index is None and dispatch_key is not None:
  449. return None
  450. schema = func.func
  451. device = "aten" if dispatch_key is None else dispatch_key.lower()
  452. backend_call = gen_static_dispatch_backend_call(
  453. func,
  454. backend_index,
  455. )
  456. try:
  457. if header:
  458. declaration, _ = gen_declaration_and_definition(
  459. schema, device, backend_call, version_info
  460. )
  461. return declaration
  462. else:
  463. _, definition = gen_declaration_and_definition(
  464. schema, device, backend_call, version_info
  465. )
  466. return definition
  467. except NotImplementedError:
  468. return None
  469. @dataclass(frozen=True)
  470. class ShimGenerator:
  471. inductor_fallback_ops: dict[str, dict[str, list[str]]]
  472. func_group_mapping: dict[OperatorName, NativeFunctionsGroup]
  473. dispatch_key: DispatchKey | None
  474. backend_indices: dict[DispatchKey, BackendIndex]
  475. header: bool # True to generate .h and False to generate .cpp
  476. extend_aoti_c_shim: bool
  477. @method_with_native_function
  478. def __call__(
  479. self,
  480. func: NativeFunction,
  481. ) -> str | None:
  482. version_info = self.inductor_fallback_ops[get_fallback_op_name(func)]
  483. result = gen_c_shim(
  484. func,
  485. version_info,
  486. self.func_group_mapping,
  487. self.dispatch_key,
  488. self.backend_indices,
  489. self.header,
  490. self.extend_aoti_c_shim,
  491. )
  492. return result
def gen_aoti_c_shim(
    native_functions: Sequence[NativeFunction],
    inductor_fallback_ops: dict[str, dict[str, list[str]]],
    func_group_mapping: dict[OperatorName, NativeFunctionsGroup],
    dispatch_key: DispatchKey | None,
    backend_indices: dict[DispatchKey, BackendIndex],
    header: bool,
    extend_aoti_c_shim: bool,
    includes: str = "",
) -> str:
    """Render the full text of a C shim header (``header=True``) or source file.

    The body is the newline-joined shim text for every function in
    ``native_functions`` that produces one. Functions for which ShimGenerator
    returns None are dropped by mapMaybe.

    ``dispatch_key is None`` selects the generic "aten" shim. ``includes`` is
    spliced into the per-operator-headers section of the .cpp and is unused
    for headers.

    The template text below is emitted verbatim. Headers are compared
    byte-for-byte against checked-in copies (see gen_aoti_c_shim_files), so
    whitespace changes here surface as ABI-check failures.
    """
    body = "\n".join(
        list(
            mapMaybe(
                ShimGenerator(
                    inductor_fallback_ops,
                    func_group_mapping,
                    dispatch_key,
                    backend_indices,
                    header,
                    extend_aoti_c_shim,
                ),
                native_functions,
            )
        )
    )
    device = "aten" if dispatch_key is None else dispatch_key.lower()
    # Monolithic-functions header, used when per-operator headers are disabled.
    include_device_functions = (
        "#include <ATen/Functions.h>"
        if dispatch_key is None
        else f"#include <ATen/{str(dispatch_key)}Functions.h>"
    )

    aten_warning = (
        (
            "\n\n// This file corresponds to the aten_shimified_ops list in torchgen/aoti/fallback_ops.py\n"
        )
        if dispatch_key is None
        else ""
    )
    warning = """
// WARNING: THIS FILE IS AUTOGENERATED BY torchgen. DO NOT MODIFY BY HAND.
// See https://github.com/pytorch/pytorch/blob/7e86a7c0155295539996e0cf422883571126073e/torchgen/gen.py#L2424-L2436 for details"""

    if header:
        # Header: declarations wrapped in extern "C" for C++ consumers.
        return (
            warning
            + aten_warning
            + textwrap.dedent("""
#pragma once
#include <torch/csrc/inductor/aoti_torch/c/shim.h>
#ifdef __cplusplus
extern "C" {
#endif
""")
            + body
            + textwrap.dedent("""
#ifdef __cplusplus
} // extern "C"
#endif
""")
        )
    else:
        # Source: include the matching (optionally extend/) header, then either the
        # monolithic ATen function headers or the caller-supplied per-op headers.
        return (
            warning
            + aten_warning
            + textwrap.dedent(f"""
#include <torch/csrc/inductor/aoti_torch/generated/{"extend/" if extend_aoti_c_shim else ""}c_shim_{device}.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>
#ifndef AT_PER_OPERATOR_HEADERS
{include_device_functions}
#include <ATen/CompositeExplicitAutogradFunctions.h>
#include <ATen/CompositeExplicitAutogradNonFunctionalFunctions.h>
#include <ATen/CompositeImplicitAutogradFunctions.h>
#else
""")
            + includes
            + textwrap.dedent("""
#endif // AT_PER_OPERATOR_HEADERS
using namespace torch::aot_inductor;
""")
            + body
        )
def gen_aoti_c_shim_files(
    aoti_fm: FileManager,
    aoti_backends: set[DispatchKey | None],
    native_functions: Sequence[NativeFunction],
    backend_indices: dict[DispatchKey, BackendIndex],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    extra_cuda_headers: str,
    extend_aoti_c_shim: bool,
    update_aoti_c_shim: bool,
) -> None:
    """Generate ``c_shim_<device>.h`` / ``c_shim_<device>.cpp`` for each backend.

    For each entry of ``aoti_backends`` (None means the generic "aten" shim):
      * select the functions listed in the relevant fallback-op table, sorted
        by fallback op name;
      * generate the header. With ``update_aoti_c_shim`` it is written out.
        Otherwise it is compared against the existing file in
        ``aoti_fm.install_dir``, and a RuntimeError carrying a unified diff is
        raised on mismatch. A missing file is only reported via print;
      * always write the .cpp. CUDA dispatch keys additionally get
        ``extra_cuda_headers`` appended to its includes.
    """
    # Map each structured delegate to its group; only the first function with
    # a delegate in each group is recorded.
    structured_func_group_dict = {}
    for func_group in structured_native_functions:
        for func in func_group.functions():
            if func.structured_delegate is not None:
                structured_func_group_dict[func.structured_delegate] = func_group
                break

    for dispatch_key in aoti_backends:
        # Use aten_shimified_ops for the aten backend, inductor_fallback_ops for others
        fallback_ops_dict = (
            aten_shimified_ops if dispatch_key is None else inductor_fallback_ops
        )
        fallbacks = {}
        for func in native_functions:
            op_name = get_fallback_op_name(func)
            if op_name in fallback_ops_dict:
                fallbacks[op_name] = func
        # Sort by op name so the generated output order is deterministic.
        fallback_native_functions = tuple(
            value for _, value in sorted(fallbacks.items())
        )

        # Use "aten" as the device name when dispatch_key is Generic
        device_name = "aten" if dispatch_key is None else dispatch_key.lower()

        # header files were checked in for ABI-compatibility checking
        header_file_name = f"c_shim_{device_name}.h"
        new_header = gen_aoti_c_shim(
            fallback_native_functions,
            fallback_ops_dict,
            structured_func_group_dict,
            dispatch_key,
            backend_indices,
            header=True,
            extend_aoti_c_shim=extend_aoti_c_shim,
            includes="",
        )
        # NOTE(review): this lambda (and the .cpp one below) captures loop
        # variables by reference; it is only correct if aoti_fm.write invokes
        # the callable before the next loop iteration - presumably it does, confirm.
        if update_aoti_c_shim:
            aoti_fm.write(
                header_file_name,
                lambda: new_header,
            )
        else:
            try:
                with open(
                    os.path.join(aoti_fm.install_dir, header_file_name)
                ) as old_file:
                    old_header = old_file.read()

                    if old_header != new_header:
                        diff = "\n".join(
                            difflib.unified_diff(
                                old_header.splitlines(),
                                new_header.splitlines(),
                                fromfile="expected",
                                tofile="actual",
                                lineterm="",
                            )
                        )

                        raise RuntimeError(f"""
The generated AOTInductor C shim header files have unexpectedly changed. This
indicates an AOTInductor fallback operator ABI backward compatibility breakage!!!
Only in a limited number of situations, this is allowed:
1. You added a fallback op to the inductor_fallback_ops list in torchgen/aoti/fallback_ops.py.
If that's the case, run `python torchgen/gen.py --update-aoti-c-shim` to add a new entry to
existing C shim header files.
2. You added a new default argument to an existing fallback op. This is clearly a BC breaking
change in the AOTInductor land. You need to annotate the new default argument in
torchgen/aoti/fallback_ops.py, and then run `python torchgen/gen.py --update-aoti-c-shim` to
update the C shim header files by creating different versions of the fallback op. See
https://github.com/pytorch/pytorch/pull/154848 as an example.
{diff}
""")
            except FileNotFoundError:
                # Best effort: with no checked-in header there is nothing to compare.
                print(
                    f"{os.path.join(aoti_fm.install_dir, header_file_name)} not found"
                )

        # cpp files are always generated on-the-fly
        def headers_for_aoti() -> str:
            # Deduplicated, sorted per-operator #include lines for this backend.
            headers = []
            for func in fallback_native_functions:
                header = get_header_for_aoti(
                    func,
                    structured_func_group_dict,
                    dispatch_key,
                    backend_indices,
                    extend_aoti_c_shim=extend_aoti_c_shim,
                )
                if header is not None:
                    headers.append(header)
            return "\n".join(sorted(set(headers)))

        extra_headers = (
            extra_cuda_headers
            if dispatch_key is not None and is_cuda_dispatch_key(dispatch_key)
            else ""
        )

        aoti_fm.write(
            f"c_shim_{device_name}.cpp",
            lambda: gen_aoti_c_shim(
                fallback_native_functions,
                fallback_ops_dict,
                structured_func_group_dict,
                dispatch_key,
                backend_indices,
                header=False,
                extend_aoti_c_shim=extend_aoti_c_shim,
                includes=headers_for_aoti() + "\n" + extra_headers,
            ),
        )