lazy_ir.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718
  1. from __future__ import annotations
  2. import itertools
  3. from abc import ABC
  4. from dataclasses import dataclass
  5. from typing import Any
  6. import torchgen.api.dispatcher as dispatcher
  7. from torchgen.api.lazy import (
  8. getValueT,
  9. isValueType,
  10. LazyArgument,
  11. LazyIrProperties,
  12. LazyIrSchema,
  13. tensorListValueT,
  14. )
  15. from torchgen.api.translate import translate
  16. from torchgen.api.types import (
  17. BaseCType,
  18. Binding,
  19. deviceT,
  20. DispatcherSignature,
  21. kernel_signature,
  22. NativeSignature,
  23. OptionalCType,
  24. VectorCType,
  25. )
  26. from torchgen.context import method_with_native_function
  27. from torchgen.dest.lazy_ts_lowering import ts_lowering_body
  28. from torchgen.model import (
  29. Argument,
  30. BackendIndex,
  31. BackendMetadata,
  32. BaseTy,
  33. BaseType,
  34. FunctionSchema,
  35. ListType,
  36. NativeFunction,
  37. NativeFunctionsGroup,
  38. )
  39. def node_ctor_arg_rvalue_string(arg: LazyArgument) -> str:
  40. """
  41. Given a LazyArgument,
  42. generate a c++ string for materializing an rvalue of that arg for passing into
  43. a lazy Node constructor.
  44. """
  45. # TODO: Matching on CType seems wrong; should be matching on Type
  46. if isValueType(arg.lazy_type):
  47. if isinstance(arg.lazy_type, BaseCType):
  48. if arg.is_wrapped_scalar:
  49. return f"node_{arg.name}"
  50. elif arg.lazy_type.type is tensorListValueT:
  51. return f"lazy_{arg.name}_tensorlist"
  52. elif arg.is_symint_or_list:
  53. return f"GetSymIntValue({arg.name})"
  54. return f"lazy_{arg.name}->GetIrValue()"
  55. elif isinstance(arg.lazy_type, OptionalCType):
  56. if arg.is_symint_or_list:
  57. # TODO: I don't understand when you should put lazy_ in the name
  58. # or not
  59. return f"{arg.name} ? std::make_optional(GetSymIntValue(*{arg.name})) : ::std::nullopt"
  60. elif arg.is_wrapped_scalar:
  61. return f"node_{arg.name}"
  62. return (
  63. f"lazy_{arg.name} ? "
  64. f"std::make_optional(lazy_{arg.name}->GetIrValue()) : "
  65. "::std::nullopt"
  66. )
  67. else:
  68. raise AssertionError(
  69. f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
  70. )
  71. else:
  72. # NB: this is here because right now we aren't treating SymInt[] as a
  73. # value type; when we do this needs to move above
  74. # NB: we cannot test arg.lazy_type as we've already specified it is an
  75. # int64_t and so we cannot distinguish between SymInt and int64_t
  76. if isinstance(arg.orig_type, ListType) and arg.orig_type.elem == BaseType(
  77. BaseTy.SymInt
  78. ):
  79. if arg.symint:
  80. return f"GetSymIntArrayRefValue({arg.name})"
  81. else:
  82. return f"std::vector<int64_t>({arg.name}.begin(), {arg.name}.end())"
  83. elif isinstance(arg.lazy_type, VectorCType) and isinstance(
  84. arg.lazy_type.elem, BaseCType
  85. ):
  86. return f"std::vector<{arg.lazy_type.elem.type}>({arg.name}.begin(), {arg.name}.end())"
  87. elif (
  88. isinstance(arg.lazy_type, OptionalCType)
  89. and isinstance(arg.lazy_type.elem, VectorCType)
  90. and isinstance(arg.lazy_type.elem.elem, BaseCType)
  91. ):
  92. return f"torch::lazy::ToOptionalVector<{arg.lazy_type.elem.elem.type}>({arg.name})"
  93. else:
  94. return f"{arg.name}"
  95. def node_ctor_inputs(schema: LazyIrSchema) -> str:
  96. """
  97. Produce a formatted string with the arguments as passed into the constructor of a node class.
  98. """
  99. node_ctor_values = [
  100. node_ctor_arg_rvalue_string(arg) for arg in schema.filtered_args()
  101. ]
  102. return ", ".join(node_ctor_values)
  103. def gen_fallback_code(
  104. schema: LazyIrSchema,
  105. sig: DispatcherSignature | NativeSignature,
  106. overload_name: str,
  107. ) -> str:
  108. """
  109. Generate code that falls back to eager conditioned on a predicate
  110. """
  111. dispatcher_sig = DispatcherSignature.from_schema(schema.func)
  112. exprs = translate(sig.arguments(), dispatcher_sig.arguments())
  113. fallback_args = ",\n ".join([a.expr for a in exprs])
  114. if len(overload_name):
  115. aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
  116. else:
  117. aten_op_str = f"ATEN_OP({schema.aten_name})"
  118. return f"""
  119. if (force_eager_fallback({aten_symbol(schema)})) {{
  120. return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
  121. {fallback_args}
  122. );
  123. }}
  124. """
  125. def aten_symbol(schema: LazyIrSchema) -> str:
  126. missing_interned_strings = {
  127. "sigmoid_backward",
  128. }
  129. if schema.aten_name in missing_interned_strings:
  130. return f'c10::Symbol::fromQualString("aten::{schema.aten_name}")'
  131. if not schema.aten_name.startswith("at::"):
  132. return f"at::aten::{schema.aten_name}"
  133. else:
  134. return schema.aten_name
  135. # converts all tensor-like arguments to meta tensors. Returns:
  136. # (1) a string containing all of the logic that does the conversions.
  137. # (2) a context, to be used by translate(), with all of the relevant bindings.
  138. def convert_to_meta_tensors(sig: DispatcherSignature) -> tuple[str, list[Binding]]:
  139. context: list[Binding] = []
  140. unwrapped_tensor_args: list[str] = []
  141. for arg in sig.arguments():
  142. if isinstance(arg.argument, Argument) and arg.argument.type.is_tensor_like():
  143. unwrapped_name = f"{arg.name}_meta"
  144. unwrapped_tensor_args.append(
  145. f"auto {unwrapped_name} = to_meta({arg.name});"
  146. )
  147. context.append(arg.with_name(unwrapped_name))
  148. else:
  149. context.append(arg)
  150. unwrap_tensor_args_str = "\n ".join(unwrapped_tensor_args)
  151. return unwrap_tensor_args_str, context
@dataclass(frozen=True)
class GenLazyIR(ABC):
    """
    Generate the C++ class declaration of a lazy IR node for an op.

    ``gen`` emits a class deriving from ``node_base``. The hook methods
    ``lowering_function``, ``create_function``, ``can_be_reused_function`` and
    ``node_base_ctor_call`` produce pieces of that class. The base versions here
    emit no lowering and no Create(), and a CanBeReused() that returns false.
    Backend-specific subclasses override them.
    """

    # Queried in __call__ for the kernel metadata (SymInt support).
    backend_index: BackendIndex
    # NOTE(review): not referenced within this class; presumably consumed by callers/subclasses.
    backend_name: str
    # C++ base class of the generated node; its ctor and ToString() are invoked.
    node_base: str
    # If True and the op has ShapePrecompute, the generated ctor takes a trailing
    # ``std::vector<torch::lazy::Shape>&& shapes`` parameter.
    use_lazy_shape: bool

    @method_with_native_function
    def __call__(self, f: NativeFunctionsGroup | NativeFunction) -> list[str]:
        """Build a LazyIrSchema for ``f`` (its functional variant for a group) and generate the node class."""
        func = f.functional.func if isinstance(f, NativeFunctionsGroup) else f.func
        metadata = self.backend_index.get_kernel(
            f.functional if isinstance(f, NativeFunctionsGroup) else f
        )
        # SymInt bindings are used only when a kernel exists and declares symint support.
        schema = LazyIrSchema(
            func, symint=metadata is not None and metadata.supports_symint()
        )
        return self.gen(schema)

    # there is no lowering functionality generated unless this IR base class is subclassed and
    # implemented as a backend-specific node
    def lowering_function(self, schema: LazyIrSchema) -> str:
        """Return the Lower() member source; empty in the base class."""
        return ""

    def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Return the static Create() member source; empty in the base class."""
        return ""

    def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
        """Return a CanBeReused() member that always returns false."""
        return f"""bool CanBeReused({node_ctor_args}) const {{
return false;
}}"""

    def node_base_ctor_call(self, schema: LazyIrSchema) -> str:
        """
        Return the C++ call to the ``node_base`` constructor used in the node's initializer list.

        The call passes the op kind, the value operands (optional ones via
        ``value_or(kNullValue)``), a shape argument chosen by the schema's shape
        property, the number of outputs, and an MHash over the scalar args.

        Raises:
            AssertionError: if a value arg's lazy type is not Base/Vector/Optional CType.
        """
        value_args = schema.filtered_args(values=True, scalars=False)
        # backends can customize the way the node base class constructor is called,
        # as long as all of its arguments can be generated from information available from the schema
        base_ctor_value_args_list = []
        for arg in value_args:
            if isinstance(arg.lazy_type, (BaseCType, VectorCType)):
                base_ctor_value_args_list.append(f"{arg.name}")
            elif isinstance(arg.lazy_type, OptionalCType):
                # Absent optional operands are passed as kNullValue.
                base_ctor_value_args_list.append(f"{arg.name}.value_or(kNullValue)")
            else:
                raise AssertionError(
                    f"Unsupported type ({arg.lazy_type}) - add support if necessary"
                )
        base_ctor_value_args = ", ".join(base_ctor_value_args_list)

        scalar_args = schema.filtered_args(values=False, scalars=True)

        # Shape construction.
        # Conditionally build shape depending on specified shape property
        if schema.properties.ShapePrecompute:
            # Shapes were computed by the caller and passed in as `shapes`.
            shape_ctor_arg = "std::move(shapes),"
        elif schema.properties.ShapeCompute:
            # Compute shapes eagerly from the ctor arguments.
            shape_args = [a.name for a in value_args]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"compute_shape_{schema.name}({', '.join(shape_args)}),"
        elif schema.properties.ShapeCache:
            # Pass a lambda over the node's operands and scalars; only the first shape is used.
            shape_args = [f"operand({i})" for i in range(len(value_args))]
            shape_args.extend(a.name for a in scalar_args)
            shape_ctor_arg = f"[&](){{ return compute_shape_{schema.name}({', '.join(shape_args)})[0]; }},"
        else:
            shape_ctor_arg = ""

        scalar_hashes = ", ".join(f"{a.name}" for a in scalar_args)

        return f"""{self.node_base}(
{schema.node_name}::ClassOpKind(),
OpList{{{base_ctor_value_args}}},
{shape_ctor_arg}
/* num_outputs */ {len(schema.returns)},
torch::lazy::MHash({scalar_hashes}))"""

    def gen(self, schema: LazyIrSchema) -> list[str]:
        """
        Return a one-element list holding the full C++ node class declaration for ``schema``.

        The class has ClassOpKind(), a constructor, ToString(), the hook-provided
        Create()/CanBeReused()/Lower() members, one member per scalar argument, and
        a ``has_<name>`` bit-field per optional value argument.
        """
        opkind = schema.opkind or aten_symbol(schema)

        # for now, we just want one IR class decl and soon after also the method defs
        # and we use the functional version not out/inplace.
        all_args = schema.filtered_args()
        scalar_args = schema.filtered_args(values=False, scalars=True)

        ctor_args = [f"const {i.lazy_type.cpp_type()}& {i.name}" for i in all_args]
        # Captured before the optional `shapes` parameter is appended, so
        # Create()/CanBeReused() do not take it.
        reuse_ctor_args = ", ".join(ctor_args)
        if self.use_lazy_shape and schema.properties.ShapePrecompute:
            ctor_args.append("std::vector<torch::lazy::Shape>&& shapes")
        node_ctor_args = ", ".join(ctor_args)

        scalar_initializers = ",\n ".join(
            [
                # This code is just special casing the mapping from string_view -> strings
                f"{a.name}({a.name}.has_value() ? ::std::make_optional(std::string(*{a.name})) : ::std::nullopt)"
                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
                else f"{a.name}({a.name})"
                for a in scalar_args
            ]
        )
        if len(scalar_initializers):
            # Leading comma continues the initializer list after the base ctor call.
            scalar_initializers = f",\n {scalar_initializers}"
        # string_view scalars are stored as owning std::string members.
        scalar_decls = "\n ".join(
            [
                f"std::string {a.name};"
                if a.lazy_type.cpp_type() == "c10::string_view"
                else f"::std::optional<std::string> {a.name};"
                if a.lazy_type.cpp_type() == "::std::optional<c10::string_view>"
                else f"{a.lazy_type.cpp_type()} {a.name};"
                for a in scalar_args
            ]
        )
        optional_values = [
            arg.name
            for arg in schema.filtered_args(values=True, scalars=False)
            if isinstance(arg.lazy_type, OptionalCType)
        ]
        has_optional_decls = "\n ".join(
            [f"bool has_{value}: 1;" for value in optional_values]
        )
        has_optional_defs = "\n ".join(
            [f"has_{value} = !!{value};" for value in optional_values]
        )

        # ToString() appends ", name=value" for each scalar; optional scalars print
        # "null" when empty, and generators print a fixed placeholder.
        members_to_string: list[str] = []
        for arg in scalar_args:
            if isinstance(arg.lazy_type, OptionalCType):
                value = f"{arg.name}.value()"
                if arg.is_generator:
                    value = '"torch.Generator()"'
                members_to_string.append(
                    f"""if ({arg.name}.has_value()) {{
ss << ", {arg.name}=" << {value};
}} else {{
ss << ", {arg.name}=null";
}}"""
                )
            else:
                members_to_string.append(f'ss << ", {arg.name}=" << {arg.name};')
        members_to_string_str = "\n ".join(members_to_string)

        return [
            f"""\
class {schema.node_name} : public {self.node_base} {{
public:
static torch::lazy::OpKind ClassOpKind() {{
return torch::lazy::OpKind({opkind});
}}
{schema.node_name}({node_ctor_args})
: {self.node_base_ctor_call(schema)}{scalar_initializers}
{{
{has_optional_defs}
}}
std::string ToString() const override {{
std::stringstream ss;
ss << {self.node_base}::ToString();
{members_to_string_str}
return ss.str();
}}
{self.create_function(schema, reuse_ctor_args)}
{self.can_be_reused_function(schema, reuse_ctor_args)}
{self.lowering_function(schema)}
{scalar_decls}
{has_optional_decls}
}};
""",
        ]
  300. @dataclass(frozen=True)
  301. class GenTSLazyIR(GenLazyIR):
  302. def lowering_function(self, schema: LazyIrSchema) -> str:
  303. signature = """
  304. torch::lazy::TSOpVector Lower(
  305. std::shared_ptr<torch::jit::GraphFunction> function,
  306. torch::lazy::TSLoweringContext* loctx) const override"""
  307. if schema.properties.LowerDeclOnly:
  308. return f"{signature};"
  309. elif schema.properties.Lower:
  310. return f"""{signature} {{
  311. {ts_lowering_body(schema)}
  312. }}
  313. """
  314. else:
  315. return ""
  316. def create_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
  317. signature = f"static NodePtr Create({node_ctor_args})"
  318. if schema.properties.CreateFnDeclOnly:
  319. return f"{signature};"
  320. elif not schema.properties.CreateFn:
  321. return ""
  322. return f"""{signature} {{
  323. return ReuseOrMakeNode<{schema.node_name}>(data);
  324. }}"""
  325. def can_be_reused_function(self, schema: LazyIrSchema, node_ctor_args: str) -> str:
  326. signature = f"bool CanBeReused({node_ctor_args}) const"
  327. if schema.properties.CanBeReusedDeclOnly:
  328. return f"{signature};"
  329. elif not schema.properties.CanBeReused:
  330. return ""
  331. value_comparison = []
  332. for arg in itertools.chain(schema.positional_values, schema.keyword_values):
  333. if isinstance(arg.lazy_type, OptionalCType):
  334. value_comparison.append(
  335. f"nullable_operand(i++) == {arg.name}.value_or(kNullValue)"
  336. )
  337. else:
  338. value_comparison.append(f"operand(i++) == {arg.name}")
  339. for arg in itertools.chain(schema.positional_scalars, schema.keyword_scalars):
  340. if isinstance(arg.lazy_type, OptionalCType):
  341. value_comparison.append(
  342. f"((!this->{arg.name}&&!{arg.name}) || (this->{arg.name}&&{arg.name} && *(this->{arg.name}) == *{arg.name}))"
  343. )
  344. else:
  345. value_comparison.append(f"this->{arg.name} == {arg.name}")
  346. value_comparison_str = " &&\n ".join(value_comparison)
  347. return f"""{signature} {{
  348. size_t i = 0;
  349. return ({value_comparison_str});
  350. }}"""
  351. @dataclass(frozen=True)
  352. class GenLazyNativeFuncDefinition:
  353. class_method_name: str
  354. backend_index: BackendIndex
  355. tensor_class: str
  356. gen_forced_fallback_code: bool
  357. backend_namespace: str
  358. get_tensorlist: str
  359. get_tensor_or_wrap_number: str
  360. try_get_tensor: str
  361. metrics_counter: str
  362. create_tensor: str
  363. create_from_first_tensor: bool
  364. create_aten_from_ltc_tensor: str
  365. tuple_aten_from_ltc_tensors: str
  366. lazy_tensor_ptr: str
  367. get_device_fn: str
  368. def lazy_tensor_decls(self, func: NativeFunction, schema: LazyIrSchema) -> str:
  369. value_args = schema.filtered_args(values=True, scalars=False)
  370. # Generates lazy_{name} variables for LazyTensors wrapping input tensors
  371. lazy_tensor_decls: list[str] = []
  372. for arg in value_args:
  373. if arg.is_wrapped_scalar:
  374. if isinstance(arg.lazy_type, OptionalCType):
  375. lazy_tensor_decls.append(
  376. f"""auto node_{arg.name} = {arg.name} ?
  377. std::make_optional(torch::lazy::LazyGraphExecutor::Get()->
  378. GetIrValueForScalarFromCodegen(*{arg.name}, *common_device)):
  379. ::std::nullopt;"""
  380. )
  381. else:
  382. lazy_tensor_decls.append(
  383. f"""auto node_{arg.name} = torch::lazy::LazyGraphExecutor::Get()->
  384. GetIrValueForScalarFromCodegen({arg.name}, *common_device);"""
  385. )
  386. elif arg.is_symint_or_list:
  387. continue # values are extracted in isValueType
  388. elif isinstance(arg.lazy_type, BaseCType):
  389. if arg.lazy_type.type is tensorListValueT:
  390. lazy_tensor_decls.append(
  391. f"auto lazy_{arg.name}_tensorlist = "
  392. f"{self.backend_namespace}::{self.get_tensorlist}({arg.name});"
  393. )
  394. else:
  395. lazy_tensor_decls.append(
  396. f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
  397. f"{self.backend_namespace}::{self.get_tensor_or_wrap_number}({arg.name}, *common_device);"
  398. )
  399. elif isinstance(arg.lazy_type, OptionalCType):
  400. if arg.lazy_type.elem != BaseCType(getValueT()):
  401. raise AssertionError(
  402. f"Expected OptionalCType elem to be {BaseCType(getValueT())}, "
  403. f"got {arg.lazy_type.elem}"
  404. )
  405. # TODO(alanwaketan): Maybe we want to apply GetLtcTensorOrCreateForWrappedNumber here, but hold it
  406. # until we encounter a real world example.
  407. lazy_tensor_decls.append(
  408. f"{self.lazy_tensor_ptr} lazy_{arg.name} = "
  409. f"{self.backend_namespace}::{self.try_get_tensor}({arg.name}.value_or(at::Tensor()));"
  410. )
  411. else:
  412. raise AssertionError(
  413. f"TODO not sure if there are other valid types to handle here ({arg.lazy_type})"
  414. )
  415. return ("\n ").join(lazy_tensor_decls)
  416. def force_eager_fallback(
  417. self,
  418. func: NativeFunction,
  419. schema: LazyIrSchema,
  420. metadata: BackendMetadata,
  421. sig: DispatcherSignature | NativeSignature,
  422. ) -> str:
  423. if self.gen_forced_fallback_code:
  424. return gen_fallback_code(
  425. schema, sig, overload_name=func.func.name.overload_name
  426. )
  427. return ""
  428. def metrics(self, func: NativeFunction, schema: LazyIrSchema) -> str:
  429. return f"{self.metrics_counter};"
  430. def get_device(self, func: NativeFunction, schema: LazyIrSchema) -> str:
  431. value_args = schema.filtered_args(values=True, scalars=False)
  432. scalar_args = schema.filtered_args(values=False, scalars=True)
  433. value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
  434. optional_device = OptionalCType(BaseCType(deviceT))
  435. optional_devices = [
  436. a.name for a in scalar_args if a.lazy_type == optional_device
  437. ]
  438. if len(value_types_names) == 0 and len(optional_devices) == 0:
  439. raise AssertionError("Expected at least one Value or Device type")
  440. get_device_str = (
  441. f"{self.get_device_fn}({', '.join(value_types_names + optional_devices)})"
  442. )
  443. return f"""auto common_device = {get_device_str};
  444. TORCH_INTERNAL_ASSERT(common_device);
  445. """
  446. def shape_inference(self, func: NativeFunction, schema: LazyIrSchema) -> str:
  447. metadata = self.backend_index.get_kernel(func)
  448. if metadata is None:
  449. raise AssertionError(f"No kernel metadata found for {func.func.name}")
  450. all_args = schema.filtered_args()
  451. returns_length = len(schema.returns)
  452. # call the meta kernel if it exists, to compute output shape/dtype for our IR
  453. # Note [Generated LTC Shape Functions]
  454. # LTC uses meta tensors from core to do shape inference when possible, and otherwise
  455. # we generate a shape function declaration that needs to be manually implemented.
  456. # How do we detect which ops are eligible to use meta tensors?
  457. # In general we should be able to use meta tensors not just on structured operators,
  458. # but also on composite operators that are implemented in terms of structured kernels.
  459. # We don't currently have a way of knowing at codegen time which ops are implemented that way.
  460. # This is the case for all view and view_copy operators however, so we're going to
  461. # use them specifically for all of the view_copy ops (instead of manually writing shape rules for all of them).
  462. is_view_copy_op = "view_copy" in func.tags
  463. is_structured = func.structured or func.structured_delegate is not None
  464. if is_structured or is_view_copy_op:
  465. meta_out = """
  466. std::vector<torch::lazy::Shape> shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())};"""
  467. if returns_length > 1:
  468. def this_shape(i: int) -> str:
  469. return f"torch::lazy::Shape(std::get<{i}>(out_meta).scalar_type(), std::get<{i}>(out_meta).sizes().vec())"
  470. shapes_str = ",".join([this_shape(i) for i in range(returns_length)])
  471. meta_out = "std::vector<torch::lazy::Shape> shapes{" + shapes_str + "};"
  472. # Convert tensor args to the meta device and call it.
  473. # (We can't pass in the input tensors directly, because they are "functional wrappers".
  474. # If any of the meta kernels call a tensor op and redispatch, we don't want to hit the functionalize kernels.)
  475. # Even at::meta:: functions might redispatch, e.g. if they call into view ops.
  476. dispatcher_sig = DispatcherSignature.from_schema(func.func)
  477. meta_conversion_str, meta_call_ctx = convert_to_meta_tensors(dispatcher_sig)
  478. meta_call_args = [
  479. e.expr
  480. for e in translate(
  481. meta_call_ctx, dispatcher_sig.arguments(), method=False
  482. )
  483. ]
  484. if is_view_copy_op:
  485. # view_copy ops always have a CompositeExplicitAutogradNonFunctional kernel
  486. if not func.has_composite_explicit_autograd_non_functional_kernel:
  487. raise AssertionError(
  488. f"view_copy op {func.func.name} must have "
  489. "CompositeExplicitAutogradNonFunctional kernel"
  490. )
  491. dispatch_ns = "compositeexplicitautogradnonfunctional"
  492. else:
  493. dispatch_ns = "meta"
  494. aten_name = schema.aten_name
  495. # TODO: this is trolling
  496. if func.func.has_symint() and metadata.supports_symint():
  497. aten_name += "_symint"
  498. shape_str = f"""\
  499. {meta_conversion_str}
  500. auto out_meta = at::{dispatch_ns}::{aten_name}({", ".join(meta_call_args)});
  501. {meta_out}"""
  502. else:
  503. shape_sig = ComputeShapeSignature(
  504. metadata.kernel, func, symint=metadata.supports_symint()
  505. )
  506. shape_str = f"""
  507. auto shapes = {shape_sig.shape_call};"""
  508. shape_str += f"""
  509. TORCH_INTERNAL_ASSERT(shapes.size() == {returns_length});"""
  510. # Calculating which dimensions are symbolic
  511. func_schema_str = "aten::" + str(func.func)
  512. shape_str += f"""
  513. if(torch::lazy::symbolicShapeEnabled()){{
  514. std::vector<torch::jit::IValue> inputs = {{ {", ".join(str(a.name) for a in all_args)} }};
  515. const char* schema_str = "{func_schema_str}";
  516. applySymbolicShapesOnLT(schema_str, inputs, shapes);
  517. }}
  518. """
  519. return shape_str
  520. def build_ir_node(self, func: NativeFunction, schema: LazyIrSchema) -> str:
  521. node_ctor_input_str = node_ctor_inputs(schema)
  522. return f"""torch::lazy::NodePtr node = torch::lazy::ReuseNode<{schema.node_name}>({node_ctor_input_str});
  523. if (!node) {{
  524. {self.shape_inference(func, schema)}
  525. node = torch::lazy::MakeNode<{schema.node_name}>({node_ctor_input_str}, std::move(shapes));
  526. CacheNode(node);
  527. }}
  528. """
  529. def create_lazy_tensor(self, first_tensor_name: str | None = None) -> str:
  530. # xla uses an instance method for tensor creation, for the time being
  531. if self.create_from_first_tensor:
  532. # TODO(whc) remove this if XLA switches to using static method for creation
  533. if first_tensor_name is None:
  534. raise AssertionError("Requires first tensor to create lazy tensor")
  535. return f"{first_tensor_name}.{self.create_tensor}"
  536. return f"{self.backend_namespace}::{self.create_tensor}"
  537. def return_aten_tensor(self, func: NativeFunction, schema: LazyIrSchema) -> str:
  538. returns_length = len(schema.returns)
  539. value_args = schema.filtered_args(values=True, scalars=False)
  540. value_types_names = [f"{a.name}" for a in value_args if not a.is_wrapped_scalar]
  541. first_tensor_name = value_types_names[0] if len(value_types_names) > 0 else None
  542. bridge_str = f"""auto result = {self.create_aten_from_ltc_tensor}(
  543. {self.create_lazy_tensor(first_tensor_name)}(std::move(node), *common_device));"""
  544. if returns_length > 1:
  545. if len(value_types_names) == 0:
  546. raise AssertionError(
  547. "Code below assumes there is at least one tensor arg"
  548. )
  549. bridge_str = f"""std::vector<{self.lazy_tensor_ptr}> lazy_tensors;
  550. for (int i = 0; i < {returns_length}; i++) {{
  551. lazy_tensors.push_back({self.create_lazy_tensor(first_tensor_name)}({getValueT()}(node, i), *common_device));
  552. }}
  553. auto result = {self.tuple_aten_from_ltc_tensors}<{returns_length}>(lazy_tensors);"""
  554. if schema.name.name.inplace or func.func.is_out_fn():
  555. if returns_length != 1:
  556. raise AssertionError(
  557. "We assumed there was no such case where an op is an in-place variant "
  558. f"and has tuple outputs, but got tuple of len {returns_length}."
  559. )
  560. bridge_str = f"""lazy_{first_tensor_name}->SetInPlaceIrValue(node);
  561. auto& result = {first_tensor_name};"""
  562. bridge_str += """
  563. return result;"""
  564. return bridge_str
  565. @method_with_native_function
  566. def __call__(self, func: NativeFunction) -> list[str]:
  567. sig = kernel_signature(func, self.backend_index)
  568. metadata = self.backend_index.get_kernel(func)
  569. if metadata is None:
  570. raise AssertionError(f"No kernel metadata found for {func.func.name}")
  571. schema = LazyIrSchema(func.func, symint=metadata.supports_symint())
  572. return [
  573. f"""\
  574. {sig.decl(name=f"{self.class_method_name}::{metadata.kernel}")} {{
  575. {self.force_eager_fallback(func, schema, metadata, sig)}
  576. {self.metrics(func, schema)}
  577. {self.get_device(func, schema)}
  578. {self.lazy_tensor_decls(func, schema)}
  579. {self.build_ir_node(func, schema)}
  580. {self.return_aten_tensor(func, schema)}
  581. }}\n
  582. """
  583. ]
  584. class ComputeShapeSignature:
  585. """
  586. Here we use the base name as the suffix of the signature to avoid generating for in-place variants.
  587. """
  588. def __init__(self, kernel_name: str, f: NativeFunction, *, symint: bool) -> None:
  589. self.__schema = LazyIrSchema(f.func, symint=symint)
  590. self.__dispatch_args = ", ".join(
  591. [a.decl() for a in dispatcher.arguments(f.func, symint=symint)]
  592. )
  593. self.__call_args = ", ".join(
  594. [f"{arg.name}" for arg in self.__schema.filtered_args(generator=True)]
  595. )
  596. self.__kernel_name = kernel_name
  597. def __decl_suffix(self) -> str:
  598. return f"{self.__kernel_name}({self.__dispatch_args})"
  599. def __call_suffix(self) -> str:
  600. return f"{self.__kernel_name}({self.__call_args})"
  601. @property
  602. def shape_decl(self) -> str:
  603. return f"TORCH_API std::vector<torch::lazy::Shape> compute_shape_{self.__decl_suffix()}"
  604. @property
  605. def shape_call(self) -> str:
  606. return f"torch::lazy::compute_shape_{self.__call_suffix()}"
  607. @dataclass(frozen=True)
  608. class GenLazyShapeInferenceDefinition:
  609. backend_index: BackendIndex
  610. tensor_class: str
  611. @method_with_native_function
  612. def __call__(self, f: NativeFunction) -> list[str]:
  613. metadata = self.backend_index.get_kernel(f)
  614. if metadata is None:
  615. raise AssertionError(f"No kernel metadata found for {f.func.name}")
  616. # See Note [Generated LTC Shape Functions]
  617. is_view_copy_op = "view_copy" in f.tags
  618. is_structured = f.structured or f.structured_delegate is not None
  619. if is_structured or is_view_copy_op:
  620. return []
  621. else:
  622. shape_sig = ComputeShapeSignature(
  623. metadata.kernel, f, symint=metadata.supports_symint()
  624. )
  625. return ["\n".join([f"{shape_sig.shape_decl};"])]
  626. def generate_non_native_lazy_ir_nodes(
  627. non_native: list[dict[str, Any]], gen_lazy_ir: GenLazyIR
  628. ) -> list[str]:
  629. """Generate the non-native lazy IR node classes"""
  630. nodes = []
  631. for op in non_native:
  632. # Set default properties for Non-Native IRs
  633. properties = LazyIrProperties("ShapeCache", "CanBeReused", "LowerDeclOnly")
  634. for p in op.get("properties", []):
  635. setattr(properties, p, True)
  636. # non-native is assumed to want symint bindings if you wrote symint
  637. schema = LazyIrSchema(FunctionSchema.parse(op["func"]), properties, symint=True)
  638. schema.opkind = op.get("opkind")
  639. nodes.append(gen_lazy_ir.gen(schema)[0])
  640. return nodes