gen_backend_stubs.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633
  1. from __future__ import annotations
  2. import argparse
  3. import os
  4. import re
  5. from collections import Counter, defaultdict, namedtuple
  6. from pathlib import Path
  7. from typing import TYPE_CHECKING
  8. import yaml
  9. import torchgen.api.dispatcher as dispatcher
  10. import torchgen.dest as dest
  11. from torchgen.api.types import DispatcherSignature
  12. from torchgen.code_template import CodeTemplate
  13. from torchgen.context import native_function_manager
  14. from torchgen.gen import get_grouped_native_functions, parse_native_yaml
  15. from torchgen.model import (
  16. BackendIndex,
  17. BackendMetadata,
  18. DispatchKey,
  19. NativeFunction,
  20. NativeFunctionsGroup,
  21. OperatorName,
  22. )
  23. from torchgen.selective_build.selector import SelectiveBuilder
  24. from torchgen.utils import concatMap, context, FileManager, NamespaceHelper, Target
  25. from torchgen.yaml_utils import YamlLoader
  26. if TYPE_CHECKING:
  27. from collections.abc import Sequence
  28. # Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
  29. # Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping)
  30. ParsedExternalYaml = namedtuple(
  31. "ParsedExternalYaml",
  32. ["backend_key", "autograd_key", "class_name", "cpp_namespace", "backend_indices"],
  33. )
  34. def parse_backend_yaml(
  35. backend_yaml_path: str,
  36. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  37. backend_indices: dict[DispatchKey, BackendIndex],
  38. ) -> ParsedExternalYaml:
  39. native_functions_map: dict[OperatorName, NativeFunction] = {
  40. f.func.name: f
  41. for f in concatMap(
  42. lambda f: [f] if isinstance(f, NativeFunction) else list(f.functions()),
  43. grouped_native_functions,
  44. )
  45. }
  46. with open(backend_yaml_path) as f:
  47. yaml_values = yaml.load(f, Loader=YamlLoader)
  48. if not isinstance(yaml_values, dict):
  49. raise AssertionError(
  50. f"Expected yaml_values to be a dict, got {type(yaml_values)}"
  51. )
  52. valid_keys = [
  53. "backend",
  54. "class_name",
  55. "cpp_namespace",
  56. "extra_headers",
  57. "supported",
  58. "autograd",
  59. "full_codegen",
  60. "non_native",
  61. "ir_gen",
  62. "symint",
  63. ]
  64. backend = yaml_values.pop("backend", None)
  65. if backend is None:
  66. raise AssertionError('You must provide a value for "backend"')
  67. class_name = yaml_values.pop("class_name", None)
  68. cpp_namespace = yaml_values.pop("cpp_namespace", None)
  69. if cpp_namespace is None:
  70. raise AssertionError('You must provide a value for "cpp_namespace"')
  71. # Mostly just defaulting to false to stick with LazyTensor convention.
  72. use_out_as_primary = yaml_values.pop("use_out_as_primary", False)
  73. if not isinstance(use_out_as_primary, bool):
  74. raise AssertionError(
  75. f"You must provide either True or False for use_out_as_primary. Provided: {use_out_as_primary}"
  76. )
  77. use_device_guard = yaml_values.pop("device_guard", False)
  78. if not isinstance(use_device_guard, bool):
  79. raise AssertionError(
  80. f"You must provide either True or False for device_guard. Provided: {use_device_guard}"
  81. )
  82. supported = yaml_values.pop("supported", [])
  83. if supported is None:
  84. supported = [] # Allow an empty list of supported ops
  85. if not isinstance(supported, list):
  86. raise AssertionError(
  87. f'expected "supported" to be a list, but got: {supported} (of type {type(supported)})'
  88. )
  89. symint = yaml_values.pop("symint", [])
  90. if symint is None:
  91. symint = [] # Allow an empty list of symint ops
  92. if not isinstance(symint, list):
  93. raise AssertionError(
  94. f'expected "symint" to be a list, but got: {symint} (of type {type(symint)})'
  95. )
  96. symint_set = set(symint)
  97. supported_autograd = yaml_values.pop("autograd", [])
  98. if not isinstance(supported_autograd, list):
  99. raise AssertionError(
  100. f'expected "autograd" to be a list, but got: {supported_autograd}'
  101. )
  102. # full_codegen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
  103. full_codegen = yaml_values.pop("full_codegen", [])
  104. supported.extend(full_codegen)
  105. # non_native is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
  106. yaml_values.pop("non_native", {})
  107. # ir_gen is ignored by parse_backend_yaml, and re-parsed in gen_lazy_tensor.py
  108. yaml_values.pop("ir_gen", {})
  109. if len(yaml_values.keys()) != 0:
  110. raise AssertionError(
  111. f"{backend_yaml_path} contains unexpected keys: {', '.join(yaml_values.keys())}. "
  112. f"Only the following keys are supported: {', '.join(valid_keys)}"
  113. )
  114. def create_backend_index(
  115. backend_ops: list[str],
  116. symint_ops: set[str],
  117. dispatch_key: DispatchKey,
  118. *,
  119. use_out_as_primary: bool,
  120. use_device_guard: bool,
  121. ) -> BackendIndex:
  122. metadata: dict[OperatorName, BackendMetadata] = {}
  123. for op in backend_ops:
  124. op_name = OperatorName.parse(op)
  125. if op_name not in native_functions_map:
  126. raise AssertionError(f"Found an invalid operator name: {op_name}")
  127. # See Note [External Backends Follow Dispatcher API]
  128. kernel_name = dispatcher.name(native_functions_map[op_name].func)
  129. if op in symint_ops:
  130. kernel_name += "_symint"
  131. # TODO: allow structured external backends later.
  132. m = BackendMetadata(
  133. kernel=kernel_name, structured=False, cpp_namespace=cpp_namespace
  134. )
  135. metadata[op_name] = m
  136. return BackendIndex(
  137. dispatch_key=dispatch_key,
  138. use_out_as_primary=use_out_as_primary,
  139. external=True,
  140. device_guard=use_device_guard,
  141. index=metadata,
  142. )
  143. backend_key: DispatchKey | None = None
  144. if len(supported) > 0:
  145. with context(
  146. lambda: f'The provided value for "backend" must be a valid DispatchKey, but got {backend}.'
  147. ):
  148. backend_key = DispatchKey.parse(backend)
  149. backend_idx = create_backend_index(
  150. supported,
  151. symint_set,
  152. backend_key,
  153. use_out_as_primary=use_out_as_primary,
  154. use_device_guard=use_device_guard,
  155. )
  156. if backend_key in backend_indices:
  157. raise AssertionError(f"Duplicate backend key: {backend_key}")
  158. backend_indices[backend_key] = backend_idx
  159. autograd_key: DispatchKey | None = None
  160. if len(supported_autograd) > 0:
  161. with context(
  162. lambda: f'The "autograd" key was specified, which indicates that you would like to override \
  163. the behavior of autograd for some operators on your backend. However "Autograd{backend}" is not a valid DispatchKey.'
  164. ):
  165. autograd_key = DispatchKey.parse(f"Autograd{backend}")
  166. autograd_idx = create_backend_index(
  167. supported_autograd,
  168. symint_set,
  169. autograd_key,
  170. use_out_as_primary=use_out_as_primary,
  171. use_device_guard=use_device_guard,
  172. )
  173. if autograd_key in backend_indices:
  174. raise AssertionError(f"Duplicate autograd key: {autograd_key}")
  175. backend_indices[autograd_key] = autograd_idx
  176. for g in grouped_native_functions:
  177. if isinstance(g, NativeFunction):
  178. forward_kernels = (
  179. []
  180. if backend_key is None
  181. else [
  182. m
  183. for m in [backend_indices[backend_key].get_kernel(g)]
  184. if m is not None
  185. ]
  186. )
  187. backward_kernels = (
  188. []
  189. if autograd_key is None
  190. else [
  191. m
  192. for m in [backend_indices[autograd_key].get_kernel(g)]
  193. if m is not None
  194. ]
  195. )
  196. else:
  197. forward_kernels = (
  198. []
  199. if backend_key is None
  200. else [
  201. m
  202. for m in [
  203. backend_indices[backend_key].get_kernel(f)
  204. for f in g.functions()
  205. ]
  206. if m is not None
  207. ]
  208. )
  209. backward_kernels = (
  210. []
  211. if autograd_key is None
  212. else [
  213. m
  214. for m in [
  215. backend_indices[autograd_key].get_kernel(f)
  216. for f in g.functions()
  217. ]
  218. if m is not None
  219. ]
  220. )
  221. forward_kernels = [f for f in forward_kernels if f is not None]
  222. backward_kernels = [f for f in backward_kernels if f is not None]
  223. if not (len(forward_kernels) == 0 or len(backward_kernels) == 0):
  224. raise AssertionError(
  225. f"Currently, all variants of an op must either be registered to a backend key, "
  226. f"or to a backend's autograd key. They cannot be mix and matched. "
  227. f"If this is something you need, feel free to create an issue! "
  228. f'{forward_kernels[0].kernel} is listed under "supported", '
  229. f'but {backward_kernels[0].kernel} is listed under "autograd".'
  230. )
  231. return ParsedExternalYaml(
  232. backend_key, autograd_key, class_name, cpp_namespace, backend_indices
  233. )
  234. def error_on_missing_kernels(
  235. native_functions: Sequence[NativeFunction],
  236. backend_indices: dict[DispatchKey, BackendIndex],
  237. backend_key: DispatchKey,
  238. autograd_key: DispatchKey | None,
  239. class_name: str,
  240. kernel_defn_file_path: str,
  241. full_codegen: list[OperatorName] | None = None,
  242. ) -> None:
  243. try:
  244. with open(kernel_defn_file_path) as f:
  245. backend_defns = f.read()
  246. except OSError as e:
  247. raise AssertionError(
  248. f"Unable to read from the specified impl_path file: {kernel_defn_file_path}"
  249. ) from e
  250. if full_codegen is None:
  251. full_codegen = []
  252. indices = [backend_indices[backend_key].index] + (
  253. [] if autograd_key is None else [backend_indices[autograd_key].index]
  254. )
  255. # Quick mapping from each OperatorName used by the external backend
  256. # to its backend kernel name
  257. expected_backend_op_names: dict[OperatorName, str] = dict(
  258. list(
  259. concatMap(
  260. lambda index: [
  261. (op_name, metadata.kernel) for op_name, metadata in index.items()
  262. ],
  263. indices,
  264. )
  265. )
  266. )
  267. expected_backend_native_funcs: list[NativeFunction] = [
  268. f
  269. for f in native_functions
  270. if f.func.name in expected_backend_op_names and f.func.name not in full_codegen
  271. ]
  272. expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict(
  273. list
  274. )
  275. for native_f in expected_backend_native_funcs:
  276. expected_backend_kernel_name_counts[
  277. expected_backend_op_names[native_f.func.name]
  278. ].append(native_f)
  279. # This just looks for lines containing "foo(", and assumes that the kernel foo has been implemented.
  280. # It might cause false negatives (we won't catch all cases), but that's ok - if we catch a missing kernel
  281. # here, then we get a nicer error message. If we miss it, you get a linker error.
  282. kernel_defn_regex = rf"(.*){class_name}::\s*([\w\d]*)\("
  283. actual_backend_kernel_name_counts = Counter(
  284. # A bit unwieldy (this could probably be moved into regex),
  285. # but we don't want to include kernel names that come from function calls,
  286. # like "return torch_xla::XLANativeFunctions::empty_strided_symint(...)".
  287. # Easy check is to ignore any lines with colons before the class name.
  288. [
  289. y
  290. for (x, y) in re.findall(kernel_defn_regex, backend_defns)
  291. if not x.endswith(":")
  292. ]
  293. )
  294. missing_kernels_err_msg = ""
  295. for expected_name, funcs in expected_backend_kernel_name_counts.items():
  296. expected_overload_count = len(funcs)
  297. actual_overload_count = actual_backend_kernel_name_counts[expected_name]
  298. if expected_overload_count != actual_overload_count:
  299. def create_decl(f: NativeFunction) -> str:
  300. with native_function_manager(f):
  301. return DispatcherSignature.from_schema(f.func).decl()
  302. expected_schemas_str = "\n".join([create_decl(f) for f in funcs])
  303. missing_kernels_err_msg += f"""
  304. {class_name} is missing a kernel definition for {expected_name}. We found {actual_overload_count} kernel(s) with that name,
  305. but expected {expected_overload_count} kernel(s). The expected function schemas for the missing operator are:
  306. {expected_schemas_str}
  307. """
  308. if missing_kernels_err_msg != "":
  309. raise AssertionError(missing_kernels_err_msg)
  310. def main() -> None:
  311. parser = argparse.ArgumentParser(description="Generate backend stub files")
  312. parser.add_argument(
  313. "-s",
  314. "--source-yaml",
  315. "--source_yaml",
  316. help="path to source yaml file containing operator external definitions",
  317. )
  318. parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
  319. parser.add_argument(
  320. "--dry-run", "--dry_run", type=bool, default=False, help="output directory"
  321. )
  322. parser.add_argument(
  323. "--impl-path",
  324. "--impl_path",
  325. type=str,
  326. default=None,
  327. help="path to the source C++ file containing kernel definitions",
  328. )
  329. options = parser.parse_args()
  330. run(options.source_yaml, options.output_dir, options.dry_run, options.impl_path)
  331. def gen_dispatchkey_nativefunc_headers(
  332. fm: FileManager,
  333. class_name: str,
  334. cpp_namespace: str,
  335. backend_indices: dict[DispatchKey, BackendIndex],
  336. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  337. backend_dispatch_key: DispatchKey,
  338. autograd_dispatch_key: DispatchKey | None,
  339. backend_name: str = "",
  340. ) -> None:
  341. if class_name is None:
  342. raise AssertionError("class_name must not be None")
  343. generated_comment = (
  344. "Autogenerated file by gen_backend_stubs.py. Do not edit directly!"
  345. )
  346. # Convert to a set first to remove duplicate kernel names.
  347. # Backends are allowed to repeat kernel names; only generate the declaration once!
  348. # Sort for deterministic output.
  349. backend_declarations = sorted(
  350. set(
  351. concatMap(
  352. lambda f: dest.compute_native_function_declaration(
  353. f, backend_indices[backend_dispatch_key]
  354. ),
  355. grouped_native_functions,
  356. )
  357. )
  358. )
  359. autograd_declarations = sorted(
  360. set(
  361. concatMap(
  362. lambda f: []
  363. if autograd_dispatch_key is None
  364. else dest.compute_native_function_declaration(
  365. f, backend_indices[autograd_dispatch_key]
  366. ),
  367. grouped_native_functions,
  368. )
  369. )
  370. )
  371. ns_helper = NamespaceHelper(cpp_namespace)
  372. fm.write_with_template(
  373. f"{backend_dispatch_key}NativeFunctions.h",
  374. "DispatchKeyNativeFunctions.h",
  375. lambda: {
  376. "generated_comment": generated_comment,
  377. "namespace_prologue": ns_helper.prologue,
  378. "class_name": class_name,
  379. "namespace_epilogue": ns_helper.epilogue,
  380. "dispatch_declarations": backend_declarations + autograd_declarations,
  381. "BackendName": backend_name,
  382. "DispatchKey": backend_dispatch_key,
  383. },
  384. )
  385. def gen_dispatcher_registrations(
  386. fm: FileManager,
  387. output_dir: str,
  388. class_name: str,
  389. backend_indices: dict[DispatchKey, BackendIndex],
  390. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  391. backend_dispatch_key: DispatchKey,
  392. dispatch_key: DispatchKey,
  393. selector: SelectiveBuilder,
  394. # build_in_tree is true for lazy TS backend and affects include paths, not used for external backends
  395. build_in_tree: bool = False,
  396. per_operator_headers: bool = False,
  397. backend_name: str = "",
  398. eager_registration: bool = True,
  399. ) -> None:
  400. headers = [
  401. f"{output_dir}/{backend_dispatch_key}NativeFunctions.h",
  402. ]
  403. if build_in_tree:
  404. external_backend_headers_str = "\n".join(f"#include <{h}>" for h in headers)
  405. else:
  406. external_backend_headers_str = "\n".join(f'#include "{h}"' for h in headers)
  407. if class_name is None:
  408. raise AssertionError("class_name must not be None")
  409. backend_index = backend_indices[dispatch_key]
  410. dispatch_registrations_body = list(
  411. concatMap(
  412. dest.RegisterDispatchKey(
  413. backend_index,
  414. Target.REGISTRATION,
  415. selector,
  416. rocm=False,
  417. symint=True,
  418. class_method_name=f"{class_name}",
  419. skip_dispatcher_op_registration=False,
  420. ),
  421. grouped_native_functions,
  422. )
  423. )
  424. newline = "\n"
  425. ns_helper = NamespaceHelper(namespace_str="at")
  426. deferred_dispatch_registrations = ""
  427. static_init_dispatch_registrations = ""
  428. if eager_registration:
  429. static_template = CodeTemplate(
  430. """\
  431. TORCH_LIBRARY_IMPL(aten, $dispatch_key, m) {
  432. $dispatch_registrations_body
  433. }"""
  434. )
  435. static_init_dispatch_registrations = static_template.substitute(
  436. dispatch_key=dispatch_key,
  437. dispatch_registrations_body=dispatch_registrations_body,
  438. )
  439. else:
  440. deferred_template = CodeTemplate(
  441. """\
  442. TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions();
  443. TORCH_API void Register${backend_name}${dispatch_key}NativeFunctions() {
  444. static auto m = MAKE_TORCH_LIBRARY_IMPL(aten, $dispatch_key);
  445. $dispatch_registrations_body
  446. }"""
  447. )
  448. deferred_dispatch_registrations = deferred_template.substitute(
  449. backend_name=backend_name,
  450. dispatch_key=dispatch_key,
  451. dispatch_registrations_body=dispatch_registrations_body,
  452. )
  453. fm.write_with_template(
  454. f"Register{dispatch_key}.cpp",
  455. "RegisterDispatchKey.cpp",
  456. lambda: {
  457. "extra_cuda_headers": "",
  458. "external_backend_headers": external_backend_headers_str,
  459. "ops_headers": "#include <ATen/Functions.h>"
  460. if not per_operator_headers
  461. else "",
  462. "DispatchKey": dispatch_key,
  463. "dispatch_namespace": dispatch_key.lower(),
  464. "dispatch_headers": dest.gen_registration_headers(
  465. backend_index, per_operator_headers=per_operator_headers, rocm=False
  466. ),
  467. "dispatch_helpers": dest.gen_registration_helpers(backend_index),
  468. "dispatch_definitions": fm.substitute_with_template(
  469. "RegisterDispatchDefinitions.ini",
  470. lambda: {
  471. "ns_prologue": ns_helper.prologue,
  472. "ns_epilogue": ns_helper.epilogue,
  473. "static_init_dispatch_registrations": static_init_dispatch_registrations,
  474. "deferred_dispatch_registrations": deferred_dispatch_registrations,
  475. "dispatch_namespace": dispatch_key.lower(),
  476. "dispatch_namespaced_definitions": "",
  477. "dispatch_anonymous_definitions": list(
  478. concatMap(
  479. dest.RegisterDispatchKey(
  480. backend_index,
  481. Target.ANONYMOUS_DEFINITION,
  482. selector,
  483. rocm=False,
  484. symint=True,
  485. class_method_name=f"{class_name}",
  486. skip_dispatcher_op_registration=False,
  487. ),
  488. grouped_native_functions,
  489. )
  490. ),
  491. },
  492. ).split(newline),
  493. },
  494. )
  495. def run(
  496. source_yaml: str, output_dir: str, dry_run: bool, impl_path: str | None = None
  497. ) -> None:
  498. # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
  499. pytorch_root = Path(__file__).absolute().parent.parent
  500. template_dir = os.path.join(pytorch_root, "aten/src/ATen/templates")
  501. def make_file_manager(install_dir: str) -> FileManager:
  502. return FileManager(
  503. install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
  504. )
  505. fm = make_file_manager(output_dir)
  506. native_yaml_path = os.path.join(
  507. pytorch_root, "aten/src/ATen/native/native_functions.yaml"
  508. )
  509. tags_yaml_path = os.path.join(pytorch_root, "aten/src/ATen/native/tags.yaml")
  510. parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
  511. native_functions, backend_indices = (
  512. parsed_yaml.native_functions,
  513. parsed_yaml.backend_indices,
  514. )
  515. grouped_native_functions = get_grouped_native_functions(native_functions)
  516. parsed_backend_yaml = parse_backend_yaml(
  517. source_yaml, grouped_native_functions, backend_indices
  518. )
  519. backend_key = parsed_backend_yaml.backend_key
  520. autograd_key = parsed_backend_yaml.autograd_key
  521. cpp_namespace = parsed_backend_yaml.cpp_namespace
  522. class_name = parsed_backend_yaml.class_name
  523. backend_indices = parsed_backend_yaml.backend_indices
  524. selector = SelectiveBuilder.get_nop_selector()
  525. if backend_key is None:
  526. # This could be useful if a backend wants to quickly set up a noop yaml file but doesn't have any kernels ready yet.
  527. return
  528. if class_name is None:
  529. # class_name is an optional argument to backend yaml file.
  530. # if specified it allows an external backend to override
  531. # the name of the class that all generated kernel definitions live under.
  532. # if not specified, its value is given as native_function_class_name.
  533. class_name = backend_indices[backend_key].native_function_class_name()
  534. if class_name is None:
  535. raise AssertionError("class_name must not be None")
  536. if impl_path is not None:
  537. error_on_missing_kernels(
  538. native_functions,
  539. backend_indices,
  540. backend_key,
  541. autograd_key,
  542. class_name,
  543. impl_path,
  544. )
  545. gen_dispatchkey_nativefunc_headers(
  546. fm,
  547. class_name,
  548. cpp_namespace,
  549. backend_indices,
  550. grouped_native_functions,
  551. backend_key,
  552. autograd_key,
  553. )
  554. for dispatch_key in (
  555. [backend_key] if autograd_key is None else [backend_key, autograd_key]
  556. ):
  557. gen_dispatcher_registrations(
  558. fm,
  559. output_dir,
  560. class_name,
  561. backend_indices,
  562. grouped_native_functions,
  563. backend_key,
  564. dispatch_key,
  565. selector,
  566. )
# Script entry point: parse the command-line arguments and run the generator.
if __name__ == "__main__":
    main()