cpp.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. from __future__ import annotations
  2. from typing import TYPE_CHECKING
  3. from typing_extensions import assert_never
  4. from torchgen import local
  5. from torchgen.api.types import (
  6. ArgName,
  7. ArrayCType,
  8. ArrayRefCType,
  9. BaseCType,
  10. BaseTypeToCppMapping,
  11. Binding,
  12. boolT,
  13. ConstRefCType,
  14. CType,
  15. dimnameListT,
  16. intArrayRefT,
  17. iTensorListRefT,
  18. ListCType,
  19. longT,
  20. MutRefCType,
  21. NamedCType,
  22. OptionalCType,
  23. optionalIntArrayRefT,
  24. optionalSymIntArrayRefT,
  25. scalarT,
  26. SpecialArgName,
  27. symIntArrayRefT,
  28. SymIntT,
  29. tensorListT,
  30. tensorOptionsT,
  31. tensorT,
  32. TupleCType,
  33. VectorCType,
  34. voidT,
  35. )
  36. from torchgen.model import (
  37. Argument,
  38. Arguments,
  39. BaseTy,
  40. BaseType,
  41. FunctionSchema,
  42. ListType,
  43. NativeFunction,
  44. OptionalType,
  45. Return,
  46. SelfArgument,
  47. TensorOptionsArguments,
  48. Type,
  49. )
  50. if TYPE_CHECKING:
  51. from collections.abc import Sequence
  52. # This file describes the translation of JIT schema to the public C++
  53. # API, which is what people use when they call functions like at::add.
  54. #
  55. # Prominent characteristics of the C++ API:
  56. #
  57. # - dtype, layout, device and pin_memory are collected into
  58. # a single C++ type TensorOptions (the native functions API
  59. # also has this, but tensor options is really most relevant
  60. # for the C++ API; it makes calling kwarg factory functions
  61. # pleasant)
  62. #
  63. # - defaulting lives here (in fact, the dispatcher is completely
  64. # oblivious of defaults!)
  65. #
  66. # BTW: policy on name collisions: we try not to have types with
  67. # collisions, but functions are fair game to collide
  68. def name(
  69. func: FunctionSchema,
  70. *,
  71. faithful_name_for_out_overloads: bool = False,
  72. symint_overload: bool = False,
  73. ) -> str:
  74. name = str(func.name.name)
  75. if symint_overload:
  76. name += "_symint"
  77. if func.is_out_fn():
  78. if faithful_name_for_out_overloads:
  79. name += "_outf"
  80. else:
  81. name += "_out"
  82. return name
  83. # Translation of "value types" in JIT schema to C++ API type. Value
  84. # types look the same no matter if they are argument types or return
  85. # types. Returns None if the type in question is not a value type.
  86. def valuetype_type(
  87. t: Type,
  88. *,
  89. binds: ArgName,
  90. mutable: bool = True,
  91. symint: bool = False,
  92. ) -> NamedCType | None:
  93. if isinstance(t, BaseType):
  94. if t.name in (BaseTy.Tensor, BaseTy.Scalar):
  95. return None
  96. elif str(t) == "SymInt":
  97. if symint:
  98. return NamedCType(binds, BaseCType(SymIntT))
  99. else:
  100. return NamedCType(binds, BaseCType(longT))
  101. # All other BaseType currently map directly to BaseCppTypes.
  102. return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))
  103. elif isinstance(t, OptionalType):
  104. elem = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
  105. if elem is None:
  106. return None
  107. return NamedCType(binds, OptionalCType(elem.type))
  108. elif isinstance(t, ListType):
  109. if str(t.elem) == "bool":
  110. if t.size is None:
  111. raise AssertionError("bool ListType must have a size")
  112. return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))
  113. else:
  114. return None
  115. else:
  116. raise AssertionError(f"unrecognized type {repr(t)}")
  117. # Translation of types occurring in JIT arguments to a C++ argument type.
  118. # If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type.
  119. # For example, we'll return std::vector<int> instead of IntArrayRef.
  120. # See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType:
    """Translate a JIT argument type into its C++ API argument type.

    Value types are delegated to ``valuetype_type``. The remaining cases
    (``Tensor``, ``Scalar``, their optionals, and lists) are handled here.

    Args:
        t: JIT schema type of the argument.
        mutable: whether the argument is written to. For ``Tensor`` and
            ``Tensor?`` this selects a ``MutRefCType``, unless
            ``local.use_const_ref_for_mutable_tensors()`` is set.
        binds: name the resulting ``NamedCType`` is bound to.
        remove_non_owning_ref_types: when set, ``int[]`` and ``SymInt[]``
            become ``VectorCType`` instead of an ArrayRef base type.
            NOTE(review): no other branch consults this flag, and the
            recursive calls below do not forward it. Types such as
            ``Scalar[]``, ``Tensor[]`` and ``int[]?`` therefore still get
            their usual translation (``ArrayRefCType``, ``tensorListT``,
            ``optionalIntArrayRefT``, ...), which may be non-owning.
        symint: when set, ``SymInt`` and ``SymInt[]`` keep their symbolic
            types. Otherwise they map to the ``longT`` / ``intArrayRefT``
            counterparts.

    Raises:
        AssertionError: for a base type that should have been a value type,
            or for an unrecognized ``Type`` subclass.
    """
    # If it's a value type, do the value type translation
    r = valuetype_type(
        t,
        binds=binds,
        mutable=mutable,
        symint=symint,
    )
    if r is not None:
        return r
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            # valuetype_type already handled every other base type.
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        if str(t.elem) == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                # NB: the mutable form drops the OptionalCType wrapper.
                return NamedCType(
                    binds, MutRefCType(BaseCType(tensorT))
                )  # TODO: fix this discrepancy
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        elif isinstance(t.elem, ListType) and str(t.elem.elem) == "SymInt":
            if symint:
                return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
            else:
                return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        # Generic optional: wrap the element's translation.
        # remove_non_owning_ref_types is not forwarded.
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if str(t.elem) == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                return NamedCType(binds, BaseCType(intArrayRefT))
        if str(t.elem) == "SymInt":
            if remove_non_owning_ref_types:
                if symint:
                    return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
                else:
                    return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                if symint:
                    return NamedCType(binds, BaseCType(symIntArrayRefT))
                else:
                    return NamedCType(binds, BaseCType(intArrayRefT))
        if str(t.elem) == "Tensor":
            # The tensor-list representation is chosen by a codegen-wide
            # setting in `local`.
            if local.use_ilistref_for_tensor_lists():
                return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
            else:
                return NamedCType(binds, BaseCType(tensorListT))
        elif str(t.elem) == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif str(t.elem) == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elif str(t.elem) == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        # Generic list: ArrayRef of the element's translation.
        # remove_non_owning_ref_types is not forwarded.
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")
  204. # Translate a JIT argument into its C++ type
  205. def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
  206. return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)
  207. # Translation of a (non-multi) return type from JIT to C++
  208. # N.B: returntype_type returns a CType, not a NamedCType.
  209. # This is mostly because of the mismatch between return types and return names.
  210. # e.g. a function with a return type of 'void' has 0 return names,
  211. # and a function with a return type of 'std::tuple' has >1 return names.
  212. def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
  213. # placeholder is ignored
  214. # NB: symint is ALWAYS respected for return types. So symint argument
  215. # here is IGNORED
  216. r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True)
  217. if r is not None:
  218. return r.type
  219. if isinstance(t, BaseType):
  220. if t.name == BaseTy.Tensor:
  221. if mutable:
  222. if local.use_const_ref_for_mutable_tensors():
  223. return ConstRefCType(BaseCType(tensorT))
  224. else:
  225. return MutRefCType(BaseCType(tensorT))
  226. else:
  227. # Note [Tensor Copy Returns]
  228. # Currently, we use "Argument.is_write" to determine
  229. # whether or not Tensor return types should be copies or references.
  230. # If that ever changes, take a look at other locations of this note!
  231. return BaseCType(tensorT)
  232. elif t.name == BaseTy.Scalar:
  233. return BaseCType(scalarT)
  234. elif isinstance(t, ListType):
  235. if mutable:
  236. raise AssertionError(
  237. "Native functions should never return a mutable tensor list. "
  238. "They should return void."
  239. )
  240. elem = returntype_type(t.elem, mutable=False)
  241. if t.size is not None:
  242. raise AssertionError(f"fixed size list returns not supported: {t}")
  243. return VectorCType(elem)
  244. elif isinstance(t, OptionalType):
  245. elem = returntype_type(t.elem, mutable=mutable)
  246. if str(t.elem) == "Tensor":
  247. return OptionalCType(elem)
  248. raise AssertionError(f"unrecognized return type {t}")
  249. # Translation of a single return to its C++ type
  250. def return_type(r: Return, *, symint: bool = False) -> CType:
  251. return returntype_type(r.type, mutable=r.is_write, symint=symint)
  252. # Translation of a full (possibly multi) return from JIT to its C++ type
  253. def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
  254. if len(rs) == 0:
  255. return BaseCType(voidT)
  256. elif len(rs) == 1:
  257. return return_type(rs[0], symint=symint)
  258. else:
  259. return TupleCType([return_type(r, symint=symint) for r in rs])
  260. def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
  261. returns: list[str] = []
  262. for i, r in enumerate(f.func.returns):
  263. # If we have an inplace function, the return argument is
  264. # implicitly named self.
  265. # TODO: Consider incorporating this into the data model
  266. if f.func.name.name.inplace:
  267. if i != 0:
  268. raise AssertionError("illegal inplace function with multiple returns")
  269. name = "self"
  270. # If we are out function, the name is the name of the
  271. # corresponding output function (r.name will get recorded
  272. # in field_name later.)
  273. elif f.func.is_out_fn():
  274. name = f.func.arguments.out[i].name
  275. # If the return argument is explicitly named...
  276. elif r.name:
  277. name_conflict = any(
  278. r.name == a.name for a in f.func.schema_order_arguments()
  279. )
  280. if name_conflict and not f.func.is_out_fn():
  281. name = f"{r.name}_return"
  282. else:
  283. name = r.name
  284. # If there is no explicit name and no fallback name was passed in, we just name the output result,
  285. # unless it's a multi-return, in which case it's result0,
  286. # result1, etc (zero-indexed)
  287. else:
  288. name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
  289. returns.append(name)
  290. return returns
  291. JIT_TO_CPP_DEFAULT = {
  292. "False": "false",
  293. "True": "true",
  294. "None": "::std::nullopt", # UGH this one is type directed
  295. "Mean": "at::Reduction::Mean",
  296. "[]": "{}",
  297. "contiguous_format": "c10::MemoryFormat::Contiguous",
  298. "long": "at::kLong",
  299. }
  300. # Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type, *, symint: bool) -> str:
    """Convert a JIT schema default ``d`` for type ``t`` into a C++ expression.

    * ``None`` for ``Tensor?`` becomes ``{}``. For other optionals it becomes
      ``::std::nullopt``; any other optional default is translated against
      the element type.
    * Single-quoted ``str`` defaults are rewritten as double-quoted C++
      literals. A bare double quote is escaped, a backslash-escaped single
      quote is unescaped, and any other backslash pair is copied through.
    * ``[...]`` list defaults become brace initializers. A digit-only default
      for a ``SymInt`` list (with ``symint``) becomes ``c10::SymInt(d)``.
    * Everything else, including ``str`` defaults that are not
      single-quoted, is looked up in ``JIT_TO_CPP_DEFAULT``, falling back to
      ``d`` unchanged.

    Raises:
        ValueError: for an unsized list whose default is not ``[...]``.
    """
    if d == "None" and str(t) == "Tensor?":
        return "{}"
    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            s = ""
            i = 1
            # Walk the characters strictly between the surrounding quotes.
            # NOTE(review): a backslash right before the closing quote pairs
            # with that quote and emits a literal single quote.
            while i + 1 < len(d):
                if d[i] != "\\":
                    if d[i] == '"':
                        s += '\\"'
                    else:
                        s += d[i]
                    i += 1
                else:
                    # Backslash escape: consume it together with the next char.
                    if d[i + 1] == "'":
                        s += "'"
                    else:
                        s += d[i : i + 2]
                    i += 2
            return f'"{s}"'
    if isinstance(t, OptionalType):
        if d == "None":
            return "::std::nullopt"
        return default_expr(d, t.elem, symint=symint)
    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return "{" + d[1:-1] + "}"
        elif symint and d.isdigit() and str(t.elem) == "SymInt":
            return f"c10::SymInt({d})"
        elif t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")
    return JIT_TO_CPP_DEFAULT.get(d, d)
  336. # Convert an argument into its C++ API form
  337. def argument(
  338. a: Argument | TensorOptionsArguments | SelfArgument,
  339. *,
  340. cpp_no_default_args: set[str],
  341. method: bool,
  342. faithful: bool,
  343. symint: bool = False,
  344. has_tensor_options: bool,
  345. ) -> list[Binding]:
  346. def sub_argument(
  347. a: Argument | TensorOptionsArguments | SelfArgument,
  348. ) -> list[Binding]:
  349. return argument(
  350. a,
  351. cpp_no_default_args=cpp_no_default_args,
  352. method=method,
  353. faithful=faithful,
  354. symint=symint,
  355. has_tensor_options=has_tensor_options,
  356. )
  357. if isinstance(a, Argument):
  358. binds: ArgName
  359. if a.name == "memory_format" and has_tensor_options:
  360. binds = SpecialArgName.possibly_redundant_memory_format
  361. else:
  362. binds = a.name
  363. default: str | None = None
  364. if a.name not in cpp_no_default_args and a.default is not None:
  365. default = default_expr(a.default, a.type, symint=symint)
  366. return [
  367. Binding(
  368. nctype=argument_type(a, binds=binds, symint=symint),
  369. name=a.name,
  370. default=default,
  371. argument=a,
  372. )
  373. ]
  374. elif isinstance(a, TensorOptionsArguments):
  375. if faithful:
  376. return (
  377. sub_argument(a.dtype)
  378. + sub_argument(a.layout)
  379. + sub_argument(a.device)
  380. + sub_argument(a.pin_memory)
  381. )
  382. else:
  383. default = None
  384. # Enforced by NativeFunction.__post_init__
  385. if "options" in cpp_no_default_args:
  386. raise AssertionError("'options' should not be in cpp_no_default_args")
  387. if all(x.default == "None" for x in a.all()):
  388. default = "{}"
  389. elif a.dtype.default == "long":
  390. default = "at::kLong" # TODO: this is wrong
  391. return [
  392. Binding(
  393. nctype=NamedCType("options", BaseCType(tensorOptionsT)),
  394. name="options",
  395. default=default,
  396. argument=a,
  397. )
  398. ]
  399. elif isinstance(a, SelfArgument):
  400. if method:
  401. # Caller is responsible for installing implicit this in context!
  402. return []
  403. else:
  404. return sub_argument(a.argument)
  405. else:
  406. assert_never(a)
  407. def arguments(
  408. arguments: Arguments,
  409. *,
  410. faithful: bool,
  411. symint: bool = False,
  412. method: bool,
  413. cpp_no_default_args: set[str],
  414. ) -> list[Binding]:
  415. args: list[Argument | TensorOptionsArguments | SelfArgument] = []
  416. if faithful:
  417. args.extend(arguments.non_out)
  418. args.extend(arguments.out)
  419. else:
  420. args.extend(arguments.out)
  421. args.extend(arguments.non_out)
  422. return [
  423. r.no_default() if faithful else r
  424. for a in args
  425. for r in argument(
  426. a,
  427. faithful=faithful,
  428. symint=symint,
  429. method=method,
  430. has_tensor_options=arguments.tensor_options is not None,
  431. cpp_no_default_args=cpp_no_default_args,
  432. )
  433. ]