gen_lazy_tensor.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593
  1. from __future__ import annotations
  2. import argparse
  3. import os
  4. from collections import namedtuple
  5. from pathlib import Path
  6. from typing import Any, TYPE_CHECKING
  7. import yaml
  8. import torchgen.dest as dest
  9. from torchgen.api.lazy import setValueT
  10. from torchgen.api.types import BaseCppType
  11. from torchgen.dest.lazy_ir import GenLazyIR, GenLazyNativeFuncDefinition, GenTSLazyIR
  12. from torchgen.gen import get_grouped_native_functions, parse_native_yaml
  13. from torchgen.gen_backend_stubs import (
  14. error_on_missing_kernels,
  15. gen_dispatcher_registrations,
  16. gen_dispatchkey_nativefunc_headers,
  17. parse_backend_yaml,
  18. )
  19. from torchgen.model import NativeFunction, NativeFunctionsGroup, OperatorName
  20. from torchgen.selective_build.selector import SelectiveBuilder
  21. from torchgen.utils import FileManager, NamespaceHelper
  22. from torchgen.yaml_utils import YamlLoader
  23. if TYPE_CHECKING:
  24. from collections.abc import Callable, Iterable, Iterator, Sequence
  25. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  26. #
  27. # Lazy Tensor Codegen
  28. #
  29. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  30. # Overview
  31. # ~~~~~~~~
  32. #
  33. # This codegen script builds on existing data models and helpers used
  34. # by all ATen backends, and adds new functionality specific to lazy
  35. # tensor backends.
  36. #
  37. # Inputs:
  38. # - <backend>_native_functions.yaml: controls which operators are
  39. # supported by the backend.
  40. #
  41. # Outputs:
  42. # (for all backends)
  43. # <DispatchKey>Ir.h defines Lazy IR classes to be constructed during tracing
  44. # - opt-in: also generate 'lowering' methods for the TorchScript backend only
  45. # <DispatchKey>NativeFunctions.cpp defines implementations of native functions which perform lazy tracing
  46. # - opt-in: 'full_codegen' section of backend yaml; 'supported' section omits these implementations
  47. # <DispatchKey>NativeFunctions.h declares implementations of native functions for both 'supported' and 'full_codegen'
  48. # ops
  49. #
  50. # Register<DispatchKey>.cpp registers all op implementations with the dispatcher
  51. # RegisterAutograd<DispatchKey>.cpp registers all autograd implementations with the dispatcher
  52. #
  53. # Validation Helpers:
  54. # - Shape Inference: errs if any ops in backend yaml require shape inference not provided by meta kernels or
  55. # implementations in torch/csrc/lazy/core/shape_inference.*
  56. # - native function impls: errs if any 'supported' ops do not have an implementation defined in the backend
  57. # (non-codegen) implementation file
  58. #
  59. #
  60. # About the Data Model
  61. # ~~~~~~~~~~~~~~~~~~~~
  62. #
  63. # Modeled after ATen codegen, the first step is to parse yaml and build a data model for the operators
  64. # we care about. In this case, the <backend>_native_functions yaml defines a subset of the core operators
  65. # (defined in more detail in the main native_functions.yaml), which will be supported by your backend.
  66. # Backends can list ops in two categories:
  67. # - `supported` ops require hand-implementations but still get codegenned declarations and registrations
  68. # - `full_codegen` ops get implementations (and IR classes) generated too
  69. #
  70. # Each native function is modeled as an object with a schema, and each schema has objects representing their
  71. # arguments. Much of the codegen is manipulation of the arguments and their types. For example, lazy tensor
  72. # backends need to transform 'at::Tensor' arguments into 'lazy::Value' objects, as well as replacing reference
  73. # types (stringref) with actual string objects, and this is done by manipulating the data model objects.
  74. # - see api/lazy.py for the lazy data model
  75. #
  76. # Once the data model is set up, the rest of this script processes a number of templates for output CPP file
  77. # and fills in the template values using helpers in `dest/lazy_ir.py` and `dest/lazy_ts_lowering.py`. These
  78. # helpers mostly iterate over functions and their arguments, outputting different c++ snippets.
  79. #
  80. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  81. # Parses the external backend's yaml, and adds a new BackendIndex for the backend's dispatch key.
  82. # Returns a Tuple of (backend_key, autograd_key, cpp_namespace, updated BackendIndex mapping, full_codegen)
  83. ParsedExternalYaml = namedtuple(
  84. "ParsedExternalYaml",
  85. ["backend_key", "autograd_key", "cpp_namespace", "backend_indices", "full_codegen"],
  86. )
  87. def parse_native_functions_keys(
  88. backend_yaml_path: str,
  89. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  90. ) -> tuple[list[OperatorName], list[Any], list[OperatorName]]:
  91. with open(backend_yaml_path) as f:
  92. yaml_values = yaml.load(f, Loader=YamlLoader)
  93. if not isinstance(yaml_values, dict):
  94. raise AssertionError(f"Expected dict from YAML, got {type(yaml_values)}")
  95. full_codegen = yaml_values.pop("full_codegen", [])
  96. non_native = yaml_values.pop("non_native", [])
  97. ir_gen = yaml_values.pop("ir_gen", [])
  98. if not isinstance(full_codegen, list):
  99. raise AssertionError(
  100. f"Expected full_codegen to be list, got {type(full_codegen)}"
  101. )
  102. if not isinstance(non_native, list):
  103. raise AssertionError(f"Expected non_native to be list, got {type(non_native)}")
  104. if not isinstance(ir_gen, list):
  105. raise AssertionError(f"Expected ir_gen to be list, got {type(ir_gen)}")
  106. full_codegen_opnames = [OperatorName.parse(name) for name in full_codegen]
  107. ir_gen_opnames = [OperatorName.parse(name) for name in ir_gen]
  108. return full_codegen_opnames, non_native, ir_gen_opnames
  109. def validate_shape_inference_header(
  110. shape_inference_hdr: str, expected_shape_infr_decls: list[str]
  111. ) -> None:
  112. try:
  113. with open(shape_inference_hdr) as f:
  114. shape_infr_decls = f.read()
  115. shape_infr_decl_lines = set(shape_infr_decls.split("\n"))
  116. except OSError as e:
  117. raise AssertionError(
  118. f"Unable to read from the specified shape_inference_hdr file: {shape_inference_hdr}"
  119. ) from e
  120. # TODO(whc) add a check for shape inference functions that have meta kernels implement and should be retired.
  121. missing_decls = [
  122. decl for decl in expected_shape_infr_decls if decl not in shape_infr_decl_lines
  123. ]
  124. if missing_decls:
  125. raise Exception( # noqa: TRY002
  126. f"""Missing shape inference function.\n
  127. Please add declare this function in {shape_inference_hdr}:\n
  128. and implement it in the corresponding shape_inference.cpp file.\n
  129. {os.linesep.join(missing_decls)}"""
  130. )
  131. # Some helper functions for the codegen.
  132. def get_ltc_helper_fns() -> str:
  133. return """\
  134. at::Tensor to_meta(const at::Tensor& tensor) {
  135. // undefined tensors can't be converted to the meta device, since they don't have sizes/strides
  136. if (!tensor.defined()) return tensor;
  137. auto out = at::native::empty_strided_meta_symint(tensor.sym_sizes(), tensor.sym_strides(), \
  138. /*dtype=*/tensor.scalar_type(), /*layout=*/tensor.layout(), \
  139. /*device=*/c10::Device(c10::kMeta), /*pin_memory=*/std::nullopt);
  140. // needs to handle wrapped numbers, so dtype promotion works properly.
  141. if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
  142. out.unsafeGetTensorImpl()->set_wrapped_number(true);
  143. }
  144. return out;
  145. }
  146. std::optional<at::Tensor> to_meta(const std::optional<at::Tensor>& tensor) {
  147. if (tensor.has_value()) {
  148. return to_meta(*tensor);
  149. }
  150. return std::nullopt;
  151. }
  152. std::vector<at::Tensor> to_meta(at::ITensorListRef t_list) {
  153. std::vector<at::Tensor> outs;
  154. outs.reserve(t_list.size());
  155. for (const auto& tensor : t_list) {
  156. outs.push_back(to_meta(tensor));
  157. }
  158. return outs;
  159. }
  160. """
  161. class default_args:
  162. node_base: str = "Node"
  163. node_base_hdr: str | None = None
  164. shape_inference_hdr: str = "torch/csrc/lazy/core/shape_inference.h"
  165. tensor_class: str = "torch::lazy::LazyTensor"
  166. tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h"
  167. lazy_ir_generator: type[GenLazyIR] = GenLazyIR
  168. native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
  169. GenLazyNativeFuncDefinition
  170. )
  171. backend_name: str = "TorchScript"
  172. def main() -> None:
  173. parser = argparse.ArgumentParser(description="Generate Lazy Tensor backend files")
  174. parser.add_argument(
  175. "-s",
  176. "--source-yaml",
  177. "--source_yaml",
  178. help="path to source yaml file containing operator external definitions",
  179. )
  180. parser.add_argument("-o", "--output-dir", "--output_dir", help="output directory")
  181. parser.add_argument(
  182. "--dry-run", "--dry_run", type=bool, default=False, help="output directory"
  183. )
  184. parser.add_argument(
  185. "--impl-path",
  186. "--impl_path",
  187. type=str,
  188. default=None,
  189. help="path to the source C++ file containing kernel definitions",
  190. )
  191. parser.add_argument(
  192. "--gen-ts-lowerings",
  193. "--gen_ts_lowerings",
  194. action="store_true",
  195. help="Generate TorchScript lowerings in addition to Lazy IR and NativeFunctions",
  196. )
  197. parser.add_argument(
  198. "--node-base",
  199. "--node_base",
  200. type=str,
  201. default=default_args.node_base,
  202. help="Name of backend specific custom Lazy IR Node base class",
  203. )
  204. parser.add_argument(
  205. "--node-base-hdr",
  206. "--node_base_hdr",
  207. type=str,
  208. default=default_args.node_base_hdr,
  209. help="Path to header file defining custom Lazy IR Node base class",
  210. )
  211. parser.add_argument(
  212. "--shape-inference-hdr",
  213. "--shape_inference_hdr",
  214. type=str,
  215. default=default_args.shape_inference_hdr,
  216. help="Path to header file defining custom Lazy shape inference functions",
  217. )
  218. parser.add_argument(
  219. "--tensor-class",
  220. "--tensor_class",
  221. type=str,
  222. default=default_args.tensor_class,
  223. help="Name of backend specific custom Lazy Tensor class",
  224. )
  225. parser.add_argument(
  226. "--tensor-class-hdr",
  227. "--tensor_class_hdr",
  228. type=str,
  229. default=default_args.tensor_class_hdr,
  230. help="Path to header file defining custom Lazy Tensor class",
  231. )
  232. parser.add_argument(
  233. "--backend-name",
  234. "--backend_name",
  235. type=str,
  236. default=default_args.backend_name,
  237. help="Name of the backend to generate",
  238. )
  239. options = parser.parse_args()
  240. # Assumes that this file lives at PYTORCH_ROOT/torchgen/gen_backend_stubs.py
  241. torch_root = Path(__file__).absolute().parents[2]
  242. aten_path = str(torch_root / "aten" / "src" / "ATen")
  243. lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator
  244. if options.gen_ts_lowerings:
  245. lazy_ir_generator = GenTSLazyIR
  246. native_func_definition_generator: type[GenLazyNativeFuncDefinition] = (
  247. default_args.native_func_definition_generator
  248. )
  249. run_gen_lazy_tensor(
  250. aten_path,
  251. options.source_yaml,
  252. options.output_dir,
  253. options.dry_run,
  254. options.impl_path,
  255. options.node_base,
  256. options.node_base_hdr,
  257. options.tensor_class,
  258. options.tensor_class_hdr,
  259. options.shape_inference_hdr,
  260. lazy_ir_generator,
  261. native_func_definition_generator,
  262. options.backend_name,
  263. )
  264. def run_gen_lazy_tensor(
  265. aten_path: str,
  266. source_yaml: str,
  267. output_dir: str,
  268. dry_run: bool,
  269. impl_path: str | None,
  270. node_base: str = default_args.node_base,
  271. node_base_hdr: str | None = default_args.node_base_hdr,
  272. tensor_class: str = default_args.tensor_class,
  273. tensor_class_hdr: str = default_args.tensor_class_hdr,
  274. shape_inference_hdr: str = default_args.shape_inference_hdr,
  275. lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator,
  276. native_func_definition_generator: type[
  277. GenLazyNativeFuncDefinition
  278. ] = default_args.native_func_definition_generator,
  279. # build_in_tree is true for TS backend and affects include paths
  280. build_in_tree: bool = False,
  281. # per_operator_headers changes whether ATen/Functions.h or individual operator headers are used
  282. # it must match how ATen was built
  283. per_operator_headers: bool = False,
  284. backend_name: str = default_args.backend_name,
  285. gen_forced_fallback_code: bool = False,
  286. use_lazy_shape: bool = True,
  287. # the following arguments are temporary customization points for xla backend migration.
  288. # do not rely on them otherwise, they should be removed once migration is complete
  289. backend_namespace: str = "torch::lazy",
  290. get_tensorlist: str = "GetTensorList",
  291. get_tensor_or_wrap_number: str = "GetLtcTensorOrCreateForWrappedNumber",
  292. try_get_tensor: str = "TryGetLtcTensor",
  293. metrics_counter: str = 'TORCH_LAZY_FN_COUNTER("lazy::")',
  294. create_tensor: str = "LazyTensor::Create",
  295. create_from_first_tensor: bool = False,
  296. create_aten_from_ltc_tensor: str = "torch::lazy::CreateAtenFromLtcTensor",
  297. tuple_aten_from_ltc_tensors: str = "torch::lazy::TupleAtenFromLtcTensors",
  298. lazy_value_class: str = "torch::lazy::Value",
  299. lazy_tensor_ptr: str = "LazyTensorPtr",
  300. get_device_fn: str = "torch::lazy::GetBackendDevice",
  301. ) -> None:
  302. lv_tokens = lazy_value_class.split("::")
  303. lv_class = lv_tokens[-1]
  304. lv_ns = "::".join(lv_tokens[:-1])
  305. setValueT(BaseCppType(lv_ns, lv_class))
  306. template_dir = os.path.join(aten_path, "templates")
  307. def make_file_manager(install_dir: str) -> FileManager:
  308. return FileManager(
  309. install_dir=install_dir, template_dir=template_dir, dry_run=dry_run
  310. )
  311. fm = make_file_manager(output_dir)
  312. native_yaml_path = os.path.join(aten_path, "native/native_functions.yaml")
  313. tags_yaml_path = os.path.join(aten_path, "native/tags.yaml")
  314. parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path)
  315. native_functions, backend_indices = (
  316. parsed_yaml.native_functions,
  317. parsed_yaml.backend_indices,
  318. )
  319. grouped_native_functions = get_grouped_native_functions(native_functions)
  320. def sort_native_function(f: NativeFunctionsGroup | NativeFunction) -> str:
  321. """
  322. We sort the native function because of the note in concat_map_codegen.
  323. TODO(alanwaketan): Remove this sorting hack once all ops are grouped properly.
  324. """
  325. func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
  326. return str(func.name.name)
  327. grouped_native_functions = sorted(
  328. grouped_native_functions, key=sort_native_function
  329. )
  330. parsed_backend_yaml = parse_backend_yaml(
  331. source_yaml, grouped_native_functions, backend_indices
  332. )
  333. backend_key = parsed_backend_yaml.backend_key
  334. autograd_key = parsed_backend_yaml.autograd_key
  335. cpp_namespace = parsed_backend_yaml.cpp_namespace
  336. backend_indices = parsed_backend_yaml.backend_indices
  337. # the following 3 keys are all processed differently
  338. # for full_codegen, we generate IR, kernels, etc
  339. # for ir_gen, we generate only IR
  340. # non_native is used to register kernels not declared in
  341. # native_functions.yaml
  342. full_codegen, non_native, ir_gen = parse_native_functions_keys(
  343. source_yaml, grouped_native_functions
  344. )
  345. def concat_map_codegen(
  346. func: Callable[[NativeFunction], Sequence[str]],
  347. xs: Iterable[NativeFunctionsGroup | NativeFunction],
  348. ops_list: list[OperatorName] = full_codegen,
  349. ) -> Iterator[str]:
  350. """
  351. We code-gen for the functional variant, which is all we need for IR classes/lowerings/shape inferences, but we
  352. only code-gen additional entries for the inplace variant for the native functions.
  353. """
  354. for x in xs:
  355. fs = list(x.functions()) if isinstance(x, NativeFunctionsGroup) else [x]
  356. for f in fs:
  357. if f.func.name in ops_list:
  358. yield from func(f)
  359. selector = SelectiveBuilder.get_nop_selector()
  360. if backend_key is None:
  361. raise AssertionError("backend_key must be non-None")
  362. class_name = backend_indices[backend_key].native_function_class_name()
  363. if impl_path is not None:
  364. error_on_missing_kernels(
  365. native_functions,
  366. backend_indices,
  367. backend_key,
  368. autograd_key,
  369. class_name,
  370. impl_path,
  371. full_codegen,
  372. )
  373. """ Validate Shape Inference Definitions
  374. Generated lazy native functions all perform shape inference, by first using a meta:: kernel
  375. if available for that op, and otherwise using a 'compute_shape_{op}' function instead. The generator
  376. knows the call signature for compute_shape_{op} because it matches the nativefunction (and meta::) signature,
  377. so it just has to check whether the op is structured and generate a call for one or the other. It's up to the dev
  378. to supply the missing compute_shape_{op} function, but the codegen at least warns you about this and provides
  379. the expected signature which can be copy-pasted into shape_inference.h.
  380. compute_shape_{op} functions are handwritten and should be replaced over time as ops get ported
  381. to structured kernels.
  382. See torch/csrc/lazy/core/shape_inference.cpp #READ THIS! for more information.
  383. """
  384. if shape_inference_hdr is not None:
  385. expected_shape_infr_decls = list(
  386. concat_map_codegen(
  387. dest.GenLazyShapeInferenceDefinition(
  388. backend_indices[backend_key], tensor_class
  389. ),
  390. grouped_native_functions,
  391. )
  392. )
  393. validate_shape_inference_header(shape_inference_hdr, expected_shape_infr_decls)
  394. if class_name is None:
  395. raise AssertionError("class_name must be non-None")
  396. # Generate nativefunction declarations
  397. # Note, eager registrations is set to False for the lazy TS backend as another LTC backend
  398. # may want to register their own lazy kernels instead of registering the TS ones.
  399. # The registration will lazily happen when init_ts_backend is called.
  400. gen_dispatchkey_nativefunc_headers(
  401. fm,
  402. class_name,
  403. cpp_namespace,
  404. backend_indices,
  405. grouped_native_functions,
  406. backend_key,
  407. autograd_key,
  408. backend_name,
  409. )
  410. # Generate Dispatcher registrations which hook up the nativefunctions
  411. for dispatch_key in (
  412. [backend_key] if autograd_key is None else [backend_key, autograd_key]
  413. ):
  414. gen_dispatcher_registrations(
  415. fm,
  416. output_dir,
  417. class_name,
  418. backend_indices,
  419. grouped_native_functions,
  420. backend_key,
  421. dispatch_key,
  422. selector,
  423. build_in_tree=build_in_tree,
  424. per_operator_headers=per_operator_headers,
  425. backend_name=backend_name,
  426. eager_registration=False,
  427. )
  428. # Generate native function impls that build IR nodes
  429. ns_helper = NamespaceHelper(cpp_namespace)
  430. fm.write_with_template(
  431. f"{backend_key}NativeFunctions.cpp",
  432. "DispatchKeyNativeFunctions.cpp",
  433. lambda: {
  434. "includes": [
  435. f"#include <{path}>"
  436. for path in [
  437. tensor_class_hdr,
  438. shape_inference_hdr,
  439. "ATen/Functions.h",
  440. "ATen/native/TensorConversions.h",
  441. "ATen/NativeFunctions.h",
  442. "ATen/CompositeExplicitAutogradNonFunctionalFunctions.h",
  443. "ATen/MetaFunctions.h",
  444. "ATen/Operators.h",
  445. "ATen/native/CPUFallback.h",
  446. "torch/csrc/lazy/core/ir_builder.h",
  447. "torch/csrc/lazy/core/lazy_graph_executor.h",
  448. "torch/csrc/lazy/core/metrics.h",
  449. "torch/csrc/lazy/core/shape.h",
  450. f"{output_dir}/{backend_key}NativeFunctions.h",
  451. f"{output_dir}/LazyIr.h",
  452. ]
  453. + (
  454. ["torch/csrc/lazy/ts_backend/ts_eager_fallback.h"]
  455. if gen_forced_fallback_code
  456. else []
  457. )
  458. ],
  459. "helper_fns": get_ltc_helper_fns(),
  460. "native_functions_include": "",
  461. "namespace_prologue": ns_helper.prologue,
  462. "namespace_epilogue": ns_helper.epilogue,
  463. "native_function_definitions": list(
  464. concat_map_codegen(
  465. native_func_definition_generator(
  466. f"{backend_key}NativeFunctions",
  467. backend_indices[backend_key],
  468. tensor_class,
  469. gen_forced_fallback_code,
  470. backend_namespace,
  471. get_tensorlist,
  472. get_tensor_or_wrap_number,
  473. try_get_tensor,
  474. metrics_counter,
  475. create_tensor,
  476. create_from_first_tensor,
  477. create_aten_from_ltc_tensor,
  478. tuple_aten_from_ltc_tensors,
  479. lazy_tensor_ptr,
  480. get_device_fn,
  481. ),
  482. grouped_native_functions,
  483. )
  484. ),
  485. },
  486. )
  487. # Generate IR node classes
  488. lazy_ir_obj = lazy_ir_generator(
  489. backend_indices[backend_key], backend_name, node_base, use_lazy_shape
  490. )
  491. fm.write_with_template(
  492. "LazyIr.h",
  493. "LazyIr.h",
  494. lambda: {
  495. "lazy_ir_sysinc": [
  496. f"#include <{path}>"
  497. for path in [
  498. "ATen/core/Formatting.h",
  499. "c10/core/ScalarType.h",
  500. "torch/csrc/lazy/core/hash.h",
  501. "torch/csrc/lazy/core/ir.h",
  502. "torch/csrc/lazy/core/shape.h",
  503. "optional",
  504. "vector",
  505. ]
  506. ],
  507. "lazy_ir_inc": [f'#include "{node_base_hdr}"']
  508. if node_base_hdr is not None
  509. else [],
  510. "ir_declarations": list(
  511. concat_map_codegen(
  512. lazy_ir_obj, grouped_native_functions, full_codegen + ir_gen
  513. )
  514. ),
  515. "namespace_prologue": ns_helper.prologue,
  516. "namespace_epilogue": ns_helper.epilogue,
  517. },
  518. )
  519. # Generate Non Native IR Node classes
  520. fm.write_with_template(
  521. "LazyNonNativeIr.h",
  522. "LazyNonNativeIr.h",
  523. lambda: {
  524. "lazy_non_native_ir_inc": [
  525. f"#include <{path}>"
  526. for path in [
  527. "torch/csrc/lazy/core/ir.h",
  528. "torch/csrc/lazy/core/ir_builder.h",
  529. "torch/csrc/lazy/core/internal_ops/ltc_ops.h",
  530. "torch/csrc/lazy/core/shape_inference.h",
  531. ]
  532. + ([node_base_hdr] if node_base_hdr else [])
  533. if path
  534. ],
  535. "non_native_ir_nodes": dest.generate_non_native_lazy_ir_nodes(
  536. non_native, lazy_ir_obj
  537. ),
  538. "namespace_prologue": ns_helper.prologue,
  539. "namespace_epilogue": ns_helper.epilogue,
  540. },
  541. )
# Script entry point: run the command-line generator when executed directly.
if __name__ == "__main__":
    main()