lazy_ts_lowering.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. from torchgen.api.lazy import LazyArgument, LazyIrSchema
  2. from torchgen.api.types import OptionalCType
  3. def ts_lowering_body(schema: LazyIrSchema) -> str:
  4. # for now, we just want one IR class decl and soon after also the method defs
  5. # and we use the functional version not out/inplace.
  6. emplace_arguments = []
  7. def get_value(arg: LazyArgument) -> str:
  8. if isinstance(arg.lazy_type, OptionalCType):
  9. return f"has_{arg.name} ? loctx->GetOutputOp(operand(i++)) : nullptr"
  10. return "loctx->GetOutputOp(operand(i++))"
  11. for arg in schema.positional_args:
  12. if arg.is_lazy_value:
  13. emplace_arguments.append(get_value(arg))
  14. continue
  15. emplace_arguments.append(f'"{arg.name}", {arg.name}')
  16. emplace_arguments_str = "\n ".join(
  17. [f"arguments.emplace_back({a});" for a in emplace_arguments]
  18. )
  19. emplace_kwarg_values = [
  20. f'"{arg.name}", {get_value(arg)}' for arg in schema.keyword_values
  21. ]
  22. emplace_kwarg_scalars = [
  23. f'"{arg.name}", {arg.name}' for arg in schema.keyword_scalars
  24. ]
  25. emplace_kwarguments = "\n ".join(
  26. [
  27. f"kwarguments.emplace_back({a});"
  28. for a in emplace_kwarg_values + emplace_kwarg_scalars
  29. ]
  30. )
  31. return f"""\
  32. std::vector<torch::jit::NamedValue> arguments;
  33. std::vector<torch::jit::NamedValue> kwarguments;
  34. arguments.reserve({len(emplace_arguments)});
  35. kwarguments.reserve({len(emplace_kwarg_values + emplace_kwarg_scalars)});
  36. size_t i = 0;
  37. {emplace_arguments_str}
  38. {emplace_kwarguments}
  39. torch::lazy::TSOpVector {schema.aten_name}_out = torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
  40. TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
  41. return {schema.aten_name}_out;
  42. """