generator.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859
  1. from __future__ import annotations
  2. import json
  3. import logging
  4. import math
  5. from typing import TYPE_CHECKING
  6. import torchgen.api.cpp as cpp
  7. from torchgen.context import native_function_manager
  8. from torchgen.model import (
  9. Argument,
  10. BackendIndex,
  11. BaseTy,
  12. BaseType,
  13. FunctionSchema,
  14. NativeFunctionsGroup,
  15. NativeFunctionsViewGroup,
  16. OptionalType,
  17. SelfArgument,
  18. TensorOptionsArguments,
  19. Type,
  20. )
  21. from torchgen.static_runtime import config
  22. if TYPE_CHECKING:
  23. from collections.abc import Sequence
# Module-level logger used by is_supported() to report why an op group was
# rejected. getLogger() with no name returns the root logger.
logger: logging.Logger = logging.getLogger()
  25. def has_alias(
  26. arguments: Sequence[Argument | SelfArgument | TensorOptionsArguments],
  27. ) -> bool:
  28. for arg in arguments:
  29. annotation = getattr(arg, "annotation", None)
  30. if not annotation:
  31. continue
  32. alias_set = getattr(annotation, "alias_set", ())
  33. if alias_set:
  34. return True
  35. return False
# Base op names that is_supported() rejects outright (it logs "BLOCKED" and
# returns False for them). Entries are grouped by the reason they are excluded;
# the long tail under the TODO has not been individually categorized.
BLOCKED_OPS = frozenset(
    (
        # non cpu ops
        "sparse_sampled_addmm",
        "hspmm",
        "linalg_svdvals",
        # sparse ops
        "sspaddmm",
        "coalesce",
        "_indices",
        "indices",
        "_values",
        "values",
        "crow_indices",
        "col_indices",
        # deprecated ops
        "floor_divide",
        "ger",
        # buggy ops
        "conj_physical",  # P495807361
        "binary_cross_entropy",  # P496394764
        "arccosh",
        # uncommon ops
        "cholesky",
        "lu_solve",
        "linalg_cholesky",
        "linalg_householder_product",
        "linalg_ldl_solve",
        "_compute_linear_combination",
        # training related ops
        "_make_dual",
        # cannot call directly
        "_fw_primal",
        # no documentation
        "_index_reduce",
        # TODO: these ones got added recently and need manual inspection
        "_new_zeros_with_same_feature_meta",
        "_conj_physical",
        "binary_cross_entropy_with_logits",
        "bincount",
        "conv_tbc",
        "copy",
        "_copy_from",
        "_copy_from_and_resize",
        "count_nonzero",
        "cudnn_affine_grid_generator",
        "cudnn_affine_grid_generator_backward",
        "cudnn_grid_sampler",
        "diag_embed",
        "embedding",
        "embedding_dense_backward",
        "_embedding_bag_dense_backward",
        "_embedding_bag_per_sample_weights_backward",
        "grid_sampler_2d",
        "_grid_sampler_2d_cpu_fallback",
        "grid_sampler_3d",
        "isnan",
        "mkldnn_linear",
        "median",
        "nanmedian",
        "_sparse_sparse_matmul",
        "batch_norm_backward_elemt",
        "_euclidean_dist",
        "pixel_shuffle",
        "pixel_unshuffle",
        "channel_shuffle",
        "_reshape_nested_backward",
        "relu",
        "prelu",
        "celu",
        "slice_scatter",
        "select_scatter",
        "diagonal_scatter",
        "sum",
        "_mkldnn_transpose",
        "_nested_tensor_from_mask",
        "_nested_from_padded",
        "_nested_tensor_size",
        "_nested_from_padded_and_nested_example",
        "_standard_gamma_grad",
        "_dirichlet_grad",
        "native_norm",
        "_sparse_softmax",
        "_sparse_softmax_backward_data",
        "_sparse_log_softmax",
        "_sparse_log_softmax_backward_data",
        "zero",
        "_sparse_addmm",
        "sparse_mask",
        "_sparse_mask_projection",
        "_to_dense",
        "_coalesce",
        "_coalesced",
        "copy_sparse_to_sparse",
        "to_sparse",
        "to_sparse_csr",
        "to_sparse_csc",
        "to_mkldnn",
        "quantize_per_tensor_dynamic",
        "quantize_per_channel",
        "q_per_channel_scales",
        "q_per_channel_zero_points",
        "int_repr",
        "_make_per_channel_quantized_tensor",
        "set",
        "lift",
        "lift_fresh",
        "lift_fresh_copy",
        "masked_scatter",
        "_masked_softmax",
        "_masked_softmax_backward",
        "put",
        "index_reduce",
        "trace",
        "_cholesky_solve_helper",
        "dist",
        "max",
        "_torch_cuda_cu_linker_symbol_op",
        "glu_jvp",
        "glu_backward_jvp",
        "hardswish_backward",
        "rrelu_with_noise_backward",
        "mkldnn_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool2d_backward",
        "_adaptive_avg_pool3d_backward",
        "isinf",
        "linalg_lu_solve",
        "linalg_vecdot",
        "linalg_matrix_exp",
        "linalg_eigvalsh",
        "_test_warn_in_autograd",
        "_test_autograd_multiple_dispatch_view",
        "_test_autograd_multiple_dispatch_view_copy",
        "_segment_reduce",
        "_segment_reduce_backward",
        "_fw_primal_copy",
        "_make_dual_copy",
        "view_as_real_copy",
        "view_as_complex_copy",
        "_conj_copy",
        "_neg_view_copy",
        "diagonal_copy",
        "detach_copy",
        "squeeze_copy",
        "t_copy",
        "unsqueeze_copy",
        "_indices_copy",
        "_values_copy",
        "indices_copy",
        "values_copy",
        "crow_indices_copy",
        "col_indices_copy",
        "ccol_indices",
        "ccol_indices_copy",
        "row_indices",
        "row_indices_copy",
        "unfold_copy",
        "alias_copy",
        "_triton_multi_head_attention",
        "special_airy_ai",
        "special_bessel_j0",
        "special_bessel_j1",
        "special_bessel_y0",
        "special_bessel_y1",
        "special_chebyshev_polynomial_t",
        "special_chebyshev_polynomial_u",
        "special_chebyshev_polynomial_v",
        "special_chebyshev_polynomial_w",
        "special_hermite_polynomial_h",
        "special_hermite_polynomial_he",
        "special_laguerre_polynomial_l",
        "special_legendre_polynomial_p",
        "special_modified_bessel_i0",
        "special_modified_bessel_i1",
        "special_modified_bessel_k0",
        "special_modified_bessel_k1",
        "special_scaled_modified_bessel_k0",
        "special_scaled_modified_bessel_k1",
        "special_shifted_chebyshev_polynomial_t",
        "special_shifted_chebyshev_polynomial_u",
        "special_shifted_chebyshev_polynomial_v",
        "special_shifted_chebyshev_polynomial_w",
        "special_spherical_bessel_j0",
        "_foobar",
        "_nested_tensor_strides",
        "_nested_tensor_storage_offsets",
        "_nested_get_values",  # no CPU backend
        "_nested_get_values_copy",  # no CPU backend
        "_nested_view_from_jagged",  # testing needs to be patched
        "_nested_view_from_jagged_copy",  # testing needs to be patched
        "_nested_view_from_buffer",  # testing needs to be patched
        "_nested_view_from_buffer_copy",  # testing needs to be patched
        "_int_mm",  # testing needs to be patched
        "_to_sparse_csc",  # testing needs to be patched
        "_to_sparse_csr",  # testing needs to be patched
        "segment_reduce",  # testing needs to be patched
    )
)
  234. def is_supported(g: NativeFunctionsGroup | NativeFunctionsViewGroup) -> bool:
  235. base_op_name = ""
  236. func = None
  237. if isinstance(g, NativeFunctionsViewGroup):
  238. base_op_name = g.view.root_name
  239. func = g.view.func
  240. else:
  241. base_op_name = g.out.func.name.name.base
  242. func = g.out.func
  243. if config.is_hand_written(g):
  244. logger.info("HAND WRITTEN: %s", base_op_name)
  245. return False
  246. if base_op_name in BLOCKED_OPS:
  247. logger.info("BLOCKED: %s", base_op_name)
  248. return False
  249. for arg in func.schema_order_arguments():
  250. maybe_method = ivalue_type_conversion_method(arg.type)
  251. if not maybe_method:
  252. # Type converting is unsupported yet.
  253. logger.info("NOT SUPPORTED TYPE CONVERTING: %s", func)
  254. return False
  255. if isinstance(g, NativeFunctionsViewGroup):
  256. # TODO: stop doing type tests by converting to C++ and then testing
  257. # the string, just test the dang thing directly
  258. if "at::Tensor" != cpp.returns_type(func.returns, symint=False).cpp_type():
  259. # Returns a non-Tensor value.
  260. logger.info("NON-TENSOR RET TYPE: %s", str(func))
  261. return False
  262. return True
  263. # For out variant ops, we need to check the arguments of its functional func.
  264. for arg in g.functional.func.schema_order_arguments():
  265. maybe_method = ivalue_type_conversion_method(arg.type)
  266. if not maybe_method:
  267. # Type converting is unsupported yet.
  268. logger.info("NOT SUPPORTED TYPE CONVERTING: %s", g.functional.func)
  269. return False
  270. if not g.structured:
  271. # In case of unstructured op, we check if it has out variant implementation.
  272. # The out variant implementation satisfies the minimum requirement that it has the output tensor as the last
  273. # parameter.
  274. if (
  275. not hasattr(g, "out")
  276. or not str(func).endswith("Tensor(a!) out) -> Tensor(a!)")
  277. or not str(func.name).endswith(".out")
  278. ):
  279. return False
  280. # TODO: stop type testing by converting to C++
  281. if "at::Tensor &" != cpp.returns_type(func.returns, symint=False).cpp_type():
  282. logger.info("NON_TENSOR RET TYPE: %s", func)
  283. return False
  284. if has_alias(func.arguments.non_out):
  285. # This op may create an alias of inputs.
  286. logger.info("INPUTS ALIAS: %s", base_op_name)
  287. return False
  288. return True
  289. def ivalue_type_conversion_method(
  290. arg_type: BaseType | OptionalType | Type,
  291. ) -> tuple[bool, str] | None:
  292. """
  293. Return the method call expression of `c10::ivalue' to convert its contained value to
  294. the expected value of `arg_type` type. For example, for `arg_type` == BaseTy.Tensor,
  295. this function returns ".toTensor()", so that it can be appended to the ivalue's
  296. variable name to get the value of the expected type.
  297. """
  298. type_conversion_methods = {
  299. BaseTy.Tensor: ((True, "toTensor()"), (False, "toOptional<at::Tensor>()")),
  300. BaseTy.int: ((False, "toInt()"), (False, "toOptional<int64_t>()")),
  301. BaseTy.bool: ((False, "toBool()"), (False, "toOptional<bool>()")),
  302. BaseTy.Scalar: ((False, "toScalar()"), (False, "toOptional<at::Scalar>()")),
  303. BaseTy.ScalarType: (
  304. (False, "toScalarType()"),
  305. (False, "toOptional<at::ScalarType>()"),
  306. ),
  307. BaseTy.str: (
  308. (False, "toStringView()"),
  309. (False, "toOptional<c10::string_view>()"),
  310. (False, "toOptional<::std::string_view>()"),
  311. ),
  312. }
  313. base_ty_object = None
  314. if isinstance(arg_type, BaseType):
  315. base_ty_object = arg_type.name
  316. elif isinstance(arg_type, OptionalType):
  317. if not isinstance(arg_type.elem, BaseType):
  318. # ListType is currently unsupported.
  319. return None
  320. base_ty_object = arg_type.elem.name
  321. else:
  322. return None
  323. if base_ty_object not in type_conversion_methods:
  324. return None
  325. methods = type_conversion_methods[base_ty_object]
  326. if isinstance(arg_type, BaseType):
  327. return methods[0]
  328. return methods[1]
  329. should_use_int_tensor_ops_ = frozenset(
  330. (
  331. "bitwise_not",
  332. "bitwise_and",
  333. "bitwise_or",
  334. "bitwise_xor",
  335. "bitwise_left_shift",
  336. "bitwise_right_shift",
  337. "gcd",
  338. "lcm",
  339. "scatter",
  340. "gather",
  341. "_convert_indices_from_coo_to_csr",
  342. "_convert_indices_from_csr_to_coo",
  343. )
  344. )
  345. should_use_complex_tensor_ops_ = frozenset(("view_as_real", "imag", "_conj"))
  346. def should_use_int_tensor(op_name: str) -> bool:
  347. return op_name in should_use_int_tensor_ops_
  348. def should_use_complex_tensor(op_name: str) -> bool:
  349. return op_name in should_use_complex_tensor_ops_
  350. test_tensor_dim_ops_1_ = frozenset(
  351. (
  352. "addmv",
  353. "index_add",
  354. "_convert_indices_from_coo_to_csr",
  355. "_convert_indices_from_csr_to_coo",
  356. "nll_loss_backward",
  357. "dot",
  358. "vdot",
  359. "outer",
  360. "ger",
  361. )
  362. )
  363. test_tensor_dim_ops_2_ = frozenset(
  364. ("addmm", "mm", "nuclear_norm", "diag", "_addmm_activation", "matrix_H", "t")
  365. )
  366. def test_tensor_dim(op_name: str) -> int:
  367. if op_name in test_tensor_dim_ops_1_:
  368. return 1
  369. if op_name in test_tensor_dim_ops_2_:
  370. return 2
  371. return 3
# JSON object mapping an op name to the literal C++ shape expression used for
# its test tensors; ops not listed here get a computed shape instead
# (see test_value_expression).
test_tensor_shapes_string: str = '{"view_as_complex": "{2, 2}"}'
# Parsed form of the JSON above, looked up by test_tensor_shape().
test_tensor_shape_json: dict[str, str] = json.loads(test_tensor_shapes_string)
  374. def test_tensor_shape(op_name: str) -> str:
  375. if op_name in test_tensor_shape_json:
  376. return test_tensor_shape_json[op_name]
  377. else:
  378. return ""
  379. def test_value_expression(
  380. arg_type: BaseType | OptionalType | Type, index: int, op_name: str
  381. ) -> str:
  382. tensor_size_ex = test_tensor_shape(op_name)
  383. if tensor_size_ex == "":
  384. num_tensors = 16 if index == 0 else 64
  385. num_dim = test_tensor_dim(op_name)
  386. size_per_dim = math.ceil(num_tensors / float(num_dim))
  387. size_per_dim += size_per_dim % 2
  388. tensor_size_ex = "{{{}}}".format(",".join([f"{size_per_dim}"] * num_dim))
  389. if should_use_int_tensor(op_name):
  390. tensor_expression = f"at::randint(1, 100, {tensor_size_ex}, at::kInt)"
  391. elif should_use_complex_tensor(op_name):
  392. tensor_expression = f"at::randn({tensor_size_ex}, at::kComplexFloat)"
  393. else:
  394. tensor_expression = f"at::rand({tensor_size_ex})"
  395. value_expressions = {
  396. BaseTy.Tensor: tensor_expression,
  397. BaseTy.int: "1",
  398. BaseTy.bool: "false",
  399. BaseTy.Scalar: "2",
  400. BaseTy.ScalarType: "at::ScalarType::Float",
  401. BaseTy.str: '"floor"',
  402. }
  403. base_ty_object = None
  404. if isinstance(arg_type, BaseType):
  405. base_ty_object = arg_type.name
  406. else:
  407. if not (
  408. isinstance(arg_type, OptionalType) and isinstance(arg_type.elem, BaseType)
  409. ):
  410. raise AssertionError(
  411. f"Expected OptionalType with BaseType elem, got {type(arg_type)}"
  412. )
  413. base_ty_object = arg_type.elem.name
  414. if base_ty_object not in value_expressions:
  415. raise AssertionError(f"Unexpected type: {base_ty_object}")
  416. value_expression = value_expressions[base_ty_object]
  417. return value_expression
  418. def generate_test_value_definitions(schema: FunctionSchema, index: int) -> str:
  419. if schema.is_out_fn():
  420. raise AssertionError(f"Expected non-out function, got {schema}")
  421. schema_name = schema.name.name.base
  422. arg_map = {}
  423. for arg in schema.schema_order_arguments():
  424. test_value_exp = test_value_expression(arg.type, index, schema_name)
  425. arg_map[arg.name] = test_value_exp
  426. config.override_test_values(arg_map, schema_name, index)
  427. arg_populations = []
  428. for arg_name, arg_value in arg_map.items():
  429. arg_populations.append(f"auto {arg_name}{index} = {arg_value}")
  430. return ";\n ".join(arg_populations) + ";"
  431. def generate_test_value_names(schema: FunctionSchema, index: int) -> str:
  432. if schema.is_out_fn():
  433. raise AssertionError(f"Expected non-out function, got {schema}")
  434. return ",".join(f"{arg.name}{index}" for arg in schema.schema_order_arguments())
# Maps a schema base type to the type name written next to a graph input in
# the generated IR test script (see generate_test_ir_arguments). Base types
# missing from this table produce an untyped graph input.
generate_test_ir_arguments_base_ty_to_type_str_ = {
    BaseTy.Tensor: "Tensor",
    BaseTy.int: "int",
    BaseTy.float: "float",
    BaseTy.str: "str",
    BaseTy.Scalar: "int",
    BaseTy.ScalarType: "int",
    BaseTy.bool: "bool",
}
  444. def generate_test_ir_arguments(
  445. schema: FunctionSchema,
  446. ) -> list[tuple[str, str | None]]:
  447. def ir_argument(arg: Argument) -> tuple[str, str | None]:
  448. t = arg.type
  449. add_optional = False
  450. if isinstance(t, OptionalType):
  451. t = t.elem
  452. add_optional = True
  453. if not isinstance(t, BaseType):
  454. raise AssertionError(f"Expected BaseType, got {type(t)}")
  455. type_str = None
  456. if t.name in generate_test_ir_arguments_base_ty_to_type_str_:
  457. type_str = generate_test_ir_arguments_base_ty_to_type_str_[t.name]
  458. if type_str and add_optional:
  459. type_str = f"{type_str}?"
  460. return ("%" + arg.name, type_str)
  461. return [ir_argument(arg) for arg in schema.schema_order_arguments()]
  462. def generate_arg_extraction(schema: FunctionSchema) -> str:
  463. arg_populations = []
  464. for i, arg in enumerate(schema.schema_order_arguments()):
  465. maybe_method = ivalue_type_conversion_method(arg.type)
  466. if not maybe_method:
  467. raise AssertionError(
  468. f"No type conversion method for {arg.name}: {arg.type}"
  469. )
  470. is_reference, type_conversion_method = maybe_method
  471. reference = "&" if is_reference else ""
  472. arg_populations.append(
  473. f"const auto{reference} {arg.name} = p_node->Input({i}).{type_conversion_method}"
  474. )
  475. return ";\n ".join(arg_populations) + ";"
  476. def get_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
  477. kernel = backend_index.get_kernel(g.functional)
  478. if g.structured or kernel is None:
  479. return cpp.name(g.functional.func)
  480. return kernel.kernel
  481. def get_out_kernel_name(g: NativeFunctionsGroup, backend_index: BackendIndex) -> str:
  482. kernel = backend_index.get_kernel(g.out)
  483. if g.structured or kernel is None:
  484. return cpp.name(g.out.func)
  485. return kernel.kernel
  486. def generate_non_out_variant_call(
  487. g: NativeFunctionsGroup, backend_index: BackendIndex
  488. ) -> str:
  489. schema = g.functional.func
  490. if schema.is_out_fn():
  491. raise AssertionError(f"Expected non-out function, got {schema}")
  492. kernel_name = get_kernel_name(g, backend_index)
  493. arg_names = (arg.name for arg in schema.schema_order_arguments())
  494. namespace_name = "cpu" if g.structured else "native"
  495. return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"
  496. def generate_call_to_view_ops(
  497. g: NativeFunctionsViewGroup, backend_index: BackendIndex
  498. ) -> str:
  499. schema = g.view.func
  500. kernel_name = cpp.name(schema)
  501. kernel = backend_index.get_kernel(g.view)
  502. if kernel:
  503. kernel_name = kernel.kernel
  504. arg_names = (arg.name for arg in schema.schema_order_arguments())
  505. namespace_name = "native"
  506. return f"at::{namespace_name}::{kernel_name}({','.join(arg_names)})"
  507. def generate_out_variant_call(
  508. g: NativeFunctionsGroup, backend_index: BackendIndex
  509. ) -> str:
  510. schema = g.out.func
  511. if not schema.is_out_fn():
  512. raise AssertionError(f"Expected out function, got {schema}")
  513. arg_names = []
  514. kernel_name = get_out_kernel_name(g, backend_index)
  515. if g.structured:
  516. # structured op starts with the output tensor argument.
  517. arg_names = [out_arg.name for out_arg in schema.arguments.out]
  518. else:
  519. arg_names = []
  520. for arg in schema.arguments.non_out:
  521. if isinstance(arg, SelfArgument):
  522. arg_names.append(arg.argument.name)
  523. else:
  524. if not isinstance(arg, Argument):
  525. raise AssertionError(f"Expected Argument, got {type(arg)}")
  526. arg_names.append(arg.name)
  527. if not g.structured:
  528. if len(schema.arguments.out) != 1:
  529. raise AssertionError(
  530. f"Expected 1 out argument, got {len(schema.arguments.out)}"
  531. )
  532. arg_names.append(schema.arguments.out[0].name)
  533. cpp_arg_names = ",".join(arg_names)
  534. namespace_name = "cpu" if g.structured else "native"
  535. return f"at::{namespace_name}::{kernel_name}({cpp_arg_names})"
  536. no_memory_resize_ops = frozenset(
  537. (
  538. "isin.Scalar_Tensor",
  539. "index_add",
  540. "dot",
  541. "vdot",
  542. "nuclear_norm",
  543. "histc",
  544. "l1_loss",
  545. "multi_margin_loss",
  546. "multilabel_margin_loss",
  547. "nll_loss",
  548. "nll_loss2d",
  549. "prod",
  550. )
  551. )
  552. def should_check_resize(schema: FunctionSchema) -> bool:
  553. schema_str = str(schema)
  554. type_variant_op_name = schema_str[: schema_str.find("(")]
  555. return type_variant_op_name not in no_memory_resize_ops
  556. def op_name_from_group(g: NativeFunctionsGroup) -> str:
  557. return g.functional.func.name.name.base
  558. class GenOpDispatcher:
  559. def out_variant(
  560. self, groups: Sequence[NativeFunctionsGroup], backend_index: BackendIndex
  561. ) -> str:
  562. if not groups:
  563. return ""
  564. generated_type_variants = []
  565. for g in groups:
  566. with native_function_manager(g):
  567. if not is_supported(g):
  568. raise AssertionError(f"Unsupported function group: {g}")
  569. if not isinstance(g, NativeFunctionsGroup):
  570. raise AssertionError(
  571. f"Expected NativeFunctionsGroup, got {type(g)}"
  572. )
  573. generated_type_variant = self.out_variant_op_generator(g, backend_index)
  574. generated_type_variants.append(generated_type_variant)
  575. op_name = op_name_from_group(groups[0])
  576. body = "\n".join(generated_type_variants)
  577. generated = f"""
  578. REGISTER_OPERATOR_FUNCTOR(
  579. aten::{op_name},
  580. aten_{op_name},
  581. [](Node* n) -> SROperator {{
  582. {body}
  583. LogAndDumpSchema(n);
  584. return nullptr;
  585. }})
  586. """
  587. return generated
  588. def view(
  589. self, groups: Sequence[NativeFunctionsViewGroup], backend_index: BackendIndex
  590. ) -> str:
  591. if not groups:
  592. return ""
  593. generated_type_variants = []
  594. for g in groups:
  595. with native_function_manager(g):
  596. if not is_supported(g):
  597. raise AssertionError(f"Unsupported view group: {g}")
  598. if not isinstance(g, NativeFunctionsViewGroup):
  599. raise AssertionError(
  600. f"Expected NativeFunctionsViewGroup, got {type(g)}"
  601. )
  602. generated_type_variant = self.view_op_generator(g, backend_index)
  603. generated_type_variants.append(generated_type_variant)
  604. op_name = config.func_name_base_str(groups[0])
  605. body = "\n".join(generated_type_variants)
  606. generated = f"""
  607. REGISTER_NATIVE_OPERATOR_FUNCTOR(
  608. aten::{op_name},
  609. aten_{op_name},
  610. [](Node* n) -> SROperator {{
  611. {body}
  612. LogAndDumpSchema(n);
  613. return nullptr;
  614. }});
  615. """
  616. return generated
  617. def out_variant_op_generator(
  618. self, g: NativeFunctionsGroup, backend_index: BackendIndex
  619. ) -> str:
  620. functional = g.functional
  621. schema = str(functional.func)
  622. populated_argument = generate_arg_extraction(g.functional.func)
  623. functional_variant_call = generate_non_out_variant_call(g, backend_index)
  624. if len(g.out.func.arguments.out) != 1:
  625. raise AssertionError(
  626. f"Expected 1 out argument, got {len(g.out.func.arguments.out)}"
  627. )
  628. out_variable_name = str(g.out.func.arguments.out[0].name)
  629. out_variant_call = generate_out_variant_call(g, backend_index)
  630. generated = f"""
  631. if (n->matches(torch::schema("aten::{schema}"))) {{
  632. return [](ProcessedNode* p_node) {{
  633. {populated_argument}
  634. if (p_node->Output(0).isNone()) {{
  635. p_node->Output(0) = {functional_variant_call};
  636. return;
  637. }}
  638. auto& {out_variable_name} = p_node->Output(0).toTensor();
  639. fastResizeToZero({out_variable_name});
  640. {out_variant_call};
  641. }};
  642. }}"""
  643. return generated
  644. def view_op_generator(
  645. self, g: NativeFunctionsViewGroup, backend_index: BackendIndex
  646. ) -> str:
  647. schema = str(g.view.func)
  648. populated_argument = generate_arg_extraction(g.view.func)
  649. functional_variant_call = generate_call_to_view_ops(g, backend_index)
  650. generated = f"""
  651. if (n->matches(torch::schema("aten::{schema}"))) {{
  652. return [](ProcessedNode* p_node) {{
  653. {populated_argument}
  654. p_node->Output(0) = {functional_variant_call};
  655. }};
  656. }}"""
  657. return generated
  658. class GenOpTestCase:
  659. def out_variant(self, groups: Sequence[NativeFunctionsGroup]) -> str:
  660. if not groups:
  661. return ""
  662. generated_type_variants = []
  663. for g in groups:
  664. with native_function_manager(g):
  665. if not is_supported(g):
  666. raise AssertionError(f"Unsupported function group: {g}")
  667. if not isinstance(g, NativeFunctionsGroup):
  668. raise AssertionError(
  669. f"Expected NativeFunctionsGroup, got {type(g)}"
  670. )
  671. generated_type_variant = self.out_variant_op_test_case_generator(g)
  672. generated_type_variants.append(generated_type_variant)
  673. return "\n".join(generated_type_variants)
  674. def view(self, groups: Sequence[NativeFunctionsViewGroup]) -> str:
  675. if not groups:
  676. return ""
  677. generated_type_variants = []
  678. for g in groups:
  679. with native_function_manager(g):
  680. if not is_supported(g):
  681. raise AssertionError(f"Unsupported view group: {g}")
  682. if not isinstance(g, NativeFunctionsViewGroup):
  683. raise AssertionError(
  684. f"Expected NativeFunctionsViewGroup, got {type(g)}"
  685. )
  686. generated_type_variant = self.view_op_test_case_generator(g)
  687. generated_type_variants.append(generated_type_variant)
  688. return "\n".join(generated_type_variants)
  689. def out_variant_op_test_case_generator(self, g: NativeFunctionsGroup) -> str:
  690. schema = g.functional.func
  691. schema_str = str(schema)
  692. if schema_str.find("(") <= 0:
  693. raise AssertionError(f"Invalid schema string: {schema_str}")
  694. type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
  695. op_name = op_name_from_group(g)
  696. if not type_variant_op_name.startswith(op_name):
  697. raise AssertionError(
  698. f"Type variant op name {type_variant_op_name} doesn't start with {op_name}"
  699. )
  700. arg_types = generate_test_ir_arguments(schema)
  701. arg_declarations = ", ".join(
  702. (
  703. arg_name if arg_type is None else f"{arg_name}: {arg_type}"
  704. for arg_name, arg_type in arg_types
  705. )
  706. )
  707. arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
  708. if not (
  709. len(schema.returns) == 1
  710. and isinstance(schema.returns[0].type, BaseType)
  711. and schema.returns[0].type.name is BaseTy.Tensor
  712. ):
  713. raise AssertionError(f"Expected single Tensor return, got {schema.returns}")
  714. test_value_definitions = generate_test_value_definitions(schema, 0)
  715. test_value_names = generate_test_value_names(schema, 0)
  716. test_value_definitions2 = generate_test_value_definitions(schema, 1)
  717. test_value_names2 = generate_test_value_names(schema, 1)
  718. check_resize = "true" if should_check_resize(schema) else "false"
  719. generated = f"""
  720. TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  721. const std::string script = R"IR(
  722. graph({arg_declarations}):
  723. %bias: None = prim::Constant()
  724. %ret = aten::{op_name}({arg_names})
  725. %cloned = aten::clone(%ret, %bias)
  726. return (%cloned)
  727. )IR";
  728. {test_value_definitions}
  729. std::vector<IValue> args{{{test_value_names}}};
  730. testStaticRuntime(script, args, {{}}, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
  731. {test_value_definitions2}
  732. std::vector<IValue> args2{{{test_value_names2}}};
  733. testStaticRuntime(script, args, args2, /*use_allclose=*/false, /*use_equalnan=*/false, /*check_resize=*/{check_resize});
  734. }}
  735. """
  736. return generated
  737. def view_op_test_case_generator(self, g: NativeFunctionsViewGroup) -> str:
  738. schema = g.view.func
  739. schema_str = str(schema)
  740. if schema_str.find("(") <= 0:
  741. raise AssertionError(f"Invalid schema string: {schema_str}")
  742. type_variant_op_name = schema_str[: schema_str.find("(")].replace(".", "_")
  743. op_name = g.view.root_name
  744. if not type_variant_op_name.startswith(op_name):
  745. raise AssertionError(
  746. f"Type variant op name {type_variant_op_name} doesn't start with {op_name}"
  747. )
  748. arg_types = generate_test_ir_arguments(schema)
  749. arg_declarations = ", ".join(
  750. (
  751. arg_name if arg_type is None else f"{arg_name}: {arg_type}"
  752. for arg_name, arg_type in arg_types
  753. )
  754. )
  755. arg_names = ", ".join((arg_name for arg_name, _ in arg_types))
  756. if not (
  757. len(schema.returns) == 1
  758. and isinstance(schema.returns[0].type, BaseType)
  759. and schema.returns[0].type.name is BaseTy.Tensor
  760. ):
  761. raise AssertionError(f"Expected single Tensor return, got {schema.returns}")
  762. test_value_definitions = generate_test_value_definitions(schema, 0)
  763. test_value_names = generate_test_value_names(schema, 0)
  764. generated = f"""
  765. TEST(StaticRuntime, autogen_{type_variant_op_name}) {{
  766. const std::string script = R"IR(
  767. graph({arg_declarations}):
  768. %bias: None = prim::Constant()
  769. %ret = aten::{op_name}({arg_names})
  770. %cloned = aten::clone(%ret, %bias)
  771. return (%cloned)
  772. )IR";
  773. {test_value_definitions}
  774. std::vector<IValue> args{{{test_value_names}}};
  775. testStaticRuntime(script, args);
  776. }}
  777. """
  778. return generated