translate.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. from __future__ import annotations
  2. from typing import NoReturn, TYPE_CHECKING
  3. from torchgen.api.types import (
  4. ArrayRefCType,
  5. BaseCType,
  6. Binding,
  7. boolT,
  8. ConstRefCType,
  9. deviceT,
  10. Expr,
  11. intArrayRefT,
  12. iOptTensorListRefT,
  13. layoutT,
  14. ListCType,
  15. longT,
  16. memoryFormatT,
  17. MutRefCType,
  18. NamedCType,
  19. opmath_t,
  20. OptionalCType,
  21. optionalIntArrayRefT,
  22. optionalScalarRefT,
  23. optionalSymIntArrayRefT,
  24. optionalTensorRefT,
  25. scalar_t,
  26. scalarT,
  27. scalarTypeT,
  28. SpecialArgName,
  29. symIntArrayRefT,
  30. SymIntT,
  31. tensorOptionsT,
  32. tensorT,
  33. VectorCType,
  34. )
  35. if TYPE_CHECKING:
  36. from collections.abc import Sequence
  37. # This file implements a small program synthesis engine that implements
  38. # conversions between one API to another.
  39. #
  40. # The key data type in this file is NamedCType, short for Named C++ semantic type. A NamedCType
  41. # represents a C++ type, plus semantic information about what it represents.
  42. # For example, consider the argument "bool pin_memory"; its normal C++ type is
  43. # "bool", but its C++ semantic type also keeps track that this represents a
  44. # "pin_memory"; you can't just use a random other boolean in a context where you
  45. # need a "pin_memory"!
  46. #
  47. # The translator takes a list of needed NamedCTypes, and then figures out how
  48. # to construct expressions with these NamedCTypes from the given bindings. Many
  49. # of these expressions are trivial (I need a Tensor other; there's a Tensor
  50. # other in scope); others are more nontrivial and may require packing/unpacking.
  51. # Some examples of non-trivial action:
  52. #
  53. # - Need the "dtype" binding? Well, maybe "dtype" isn't available
  54. # in the context, instead, "options" is, and you need to extract
  55. # it from there. (Gather)
  56. #
  57. # - Need the "options" binding? Well, maybe "options" isn't available
  58. # in the context, and you need to construct it from "dtype", "device",
  59. # etc. (Scatter)
  60. #
  61. # - Need the "memory_format" binding? Well, actually, it's available
  62. # from both "memory_format" and "options", so you had better make sure
  63. # they are consistent. (Join)
  64. options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
  65. out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))
  66. longVec_ctype = VectorCType(BaseCType(longT))
  67. longSymVec_ctype = VectorCType(BaseCType(SymIntT))
  68. optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT)))
  69. optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
  70. optionalTensor_ctype = OptionalCType(BaseCType(tensorT))
  71. class UnsatError(RuntimeError):
  72. pass
  73. # Given a set of in-scope bindings and a set of target bindings, synthesize
  74. # a list of expressions that uses only the in-scope bindings (bindings) that
  75. # have all of the types of goals. You may want to use this function if
  76. # you're generating code for a function like:
  77. #
  78. # void f({args}) {
  79. # g({exprs}); // g is a different API
  80. # }
  81. #
  82. # and you need to generate "exprs".
  83. #
  84. # Typically, a list of Bindings is convenient to get (you usually call something
  85. # like arguments() to get them); but technically you only need less information:
  86. # for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
  87. # 'goals', an (ordered) list of NamedCType goals is sufficient. If you are doing
  88. # something more complicated, e.g., tracking the set of bindings in a context,
  89. # you may find using these smaller types more convenient.
  90. def translate(
  91. bindings: Sequence[Expr | Binding],
  92. goals: Sequence[NamedCType | Binding],
  93. *,
  94. method: bool = False,
  95. allow_expensive_conversions: bool = False,
  96. ) -> list[Expr]:
  97. binding_exprs: list[Expr] = []
  98. for b in bindings:
  99. if isinstance(b, Binding):
  100. binding_exprs.append(
  101. Expr(
  102. expr=b.name,
  103. type=b.nctype,
  104. )
  105. )
  106. else:
  107. binding_exprs.append(b)
  108. goal_ctypes: list[NamedCType] = []
  109. for g in goals:
  110. if isinstance(g, Binding):
  111. goal_ctypes.append(g.nctype)
  112. else:
  113. goal_ctypes.append(g)
  114. # Add all the bindings to the context
  115. ctx: dict[NamedCType, str] = {}
  116. for b in binding_exprs:
  117. ctx[b.type] = b.expr
  118. # While we're at it, do some simple forward inference, looking through
  119. # constructors.
  120. #
  121. # NB: When should you do forward inference versus backward inference?
  122. # The general idea:
  123. #
  124. # - Backward inference WHEN the goal gets smaller
  125. # - Forward inference WHEN the hypothesis gets smaller
  126. #
  127. # This helps ensure termination: backward inference starts with a goal
  128. # and tries to make it simpler and simpler until it's trivial; if the
  129. # goal can grow in size, we blow up to a really huge goal size.
  130. # Similarly, with forward inference we take hypotheses and decompose
  131. # them into simpler hypotheses; if hypotheses could expand in size,
  132. # we also have potential nontermination. (In the code below, forward
  133. # inference is only ever carried out at a single step, but you could
  134. # imagine repeated application of forward inference being profitable.)
  135. #
  136. # A good starting point in the literature for exploring more about proof
  137. # search are these lecture notes
  138. # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
  139. #
  140. # TODO: My kingdom for a pattern matcher
  141. # https://www.python.org/dev/peps/pep-0634/
  142. #
  143. # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
  144. # Fix this by implementing some sort of sharing so that if multiple
  145. # goals share the same expression, we only compute it once. This seems
  146. # to matter in practice as compiler is often unwilling to CSE nontrivial
  147. # expressions like scalar.to<scalar_t>()
  148. t = b.type
  149. if (
  150. isinstance(t, ConstRefCType)
  151. and isinstance(t.elem, OptionalCType)
  152. and isinstance(t.elem.elem, BaseCType)
  153. and str(t.elem.elem.type) == "at::Tensor"
  154. ):
  155. ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = (
  156. f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
  157. )
  158. if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
  159. ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = (
  160. f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
  161. )
  162. if t.type == ConstRefCType(BaseCType(scalarT)):
  163. ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"
  164. if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
  165. ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = (
  166. f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
  167. )
  168. if t.type == BaseCType(scalar_t):
  169. ctx[NamedCType(t.name, BaseCType(opmath_t))] = (
  170. f"static_cast<opmath_t>({b.expr})"
  171. )
  172. # [Note: IOptTensorListRef]
  173. if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
  174. ctx[NamedCType(t.name, BaseCType(iOptTensorListRefT))] = (
  175. f"at::IOptTensorListRef({b.expr})"
  176. )
  177. # Add implicit bindings if the generated code is inside a Tensor method
  178. if method:
  179. ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = (
  180. "const_cast<Tensor&>(*this)"
  181. )
  182. ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = (
  183. "const_cast<Tensor&>(*this)"
  184. )
  185. # This is better! Byte-for-byte compat
  186. # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"
  187. def unsat(goal: NamedCType) -> NoReturn:
  188. ctx_desc = "\n".join(
  189. f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
  190. )
  191. raise UnsatError(
  192. f"""
  193. Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
  194. When I failed, the following bindings were available in the context:
  195. {ctx_desc}
  196. This probably means there is a missing rule in the rules of torchgen.api.translate.
  197. Check this module for more information.
  198. """
  199. )
  200. # A shitty backtracking search implementation. It's shitty because it
  201. # does backtracking via stack (bad idea!) and for the most part tries to
  202. # avoid backtracking. In particular, if
  203. # direct=True, we won't try to do any fancy synthesis, just trivial
  204. # conversions (e.g., "T a" is OK for "const T& a"). So all of the
  205. # existing rules in this function simply try to solve immediately,
  206. # and bail if things don't work out.
  207. def solve(goal: NamedCType, *, direct: bool) -> str:
  208. def direct_solve(goal: NamedCType) -> str:
  209. return solve(goal, direct=True)
  210. if goal in ctx:
  211. # Trivial
  212. return ctx[goal]
  213. # const & is satisfied with mutable &
  214. if isinstance(goal.type, ConstRefCType):
  215. try:
  216. # WARNING: not strictly decreasing; be careful not
  217. # to add a direct conversion that goes satisfies
  218. # mutable& with const&
  219. return solve(
  220. NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
  221. )
  222. except UnsatError:
  223. pass
  224. # mutable & is satisfied with value
  225. if isinstance(goal.type, MutRefCType):
  226. try:
  227. return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
  228. except UnsatError:
  229. pass
  230. # TODO: These are referentially equal, shouldn't have to do this;
  231. # ensuring we don't use type synonym IntArrayRef in codegen would
  232. # help
  233. if goal.type == ArrayRefCType(BaseCType(longT)):
  234. return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)
  235. if direct:
  236. unsat(goal)
  237. # For now, all of these rules are mutually exclusive.
  238. if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
  239. memory_format = direct_solve(
  240. NamedCType(
  241. SpecialArgName.possibly_redundant_memory_format,
  242. OptionalCType(BaseCType(memoryFormatT)),
  243. )
  244. )
  245. # No need to join "memory_format" and "options" if the target API takes "options" directly.
  246. # Otherwise it will cause the redundant memory_format error.
  247. if options_ctype in goal_ctypes:
  248. return memory_format
  249. try:
  250. options = direct_solve(options_ctype)
  251. return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
  252. except UnsatError:
  253. return memory_format
  254. elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
  255. dtype = direct_solve(
  256. NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
  257. )
  258. pin_memory = direct_solve(
  259. NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
  260. )
  261. device = direct_solve(
  262. NamedCType("device", OptionalCType(BaseCType(deviceT)))
  263. )
  264. layout = direct_solve(
  265. NamedCType("layout", OptionalCType(BaseCType(layoutT)))
  266. )
  267. return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
  268. elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
  269. try:
  270. options = direct_solve(options_ctype)
  271. return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
  272. except UnsatError:
  273. out_tensor = direct_solve(out_tensor_ctype)
  274. return f"{out_tensor}.scalar_type()"
  275. elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
  276. try:
  277. options = direct_solve(options_ctype)
  278. return f"{options}.layout_opt()"
  279. except UnsatError:
  280. out_tensor = direct_solve(out_tensor_ctype)
  281. return f"{out_tensor}.layout()"
  282. elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
  283. try:
  284. options = direct_solve(options_ctype)
  285. return f"{options}.device_opt()"
  286. except UnsatError:
  287. out_tensor = direct_solve(out_tensor_ctype)
  288. return f"{out_tensor}.device()"
  289. elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
  290. try:
  291. options = direct_solve(options_ctype)
  292. return f"{options}.pinned_memory_opt()"
  293. except UnsatError:
  294. # If we're calling a factory op from its out= variant,
  295. # We don't actually care about the value of pin_memory.
  296. out_tensor = direct_solve(out_tensor_ctype)
  297. return "::std::nullopt"
  298. # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
  299. elif goal.type == BaseCType(intArrayRefT):
  300. try:
  301. return direct_solve(NamedCType(goal.name, longVec_ctype))
  302. except UnsatError:
  303. # We can also go SymIntArrayRef -> IntArrayRef
  304. symIntArrayRef_type = direct_solve(
  305. NamedCType(goal.name, BaseCType(symIntArrayRefT))
  306. )
  307. return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
  308. elif goal.type == BaseCType(symIntArrayRefT):
  309. try:
  310. r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
  311. return f"c10::fromIntArrayRefSlow({r})"
  312. except UnsatError:
  313. return direct_solve(NamedCType(goal.name, longSymVec_ctype))
  314. elif goal.type == BaseCType(SymIntT):
  315. return direct_solve(NamedCType(goal.name, BaseCType(longT)))
  316. elif goal.type == OptionalCType(BaseCType(SymIntT)):
  317. argname = direct_solve(
  318. NamedCType(goal.name, OptionalCType(BaseCType(longT)))
  319. )
  320. return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt"
  321. elif goal.type == BaseCType(longT):
  322. symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
  323. return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
  324. elif goal.type == OptionalCType(BaseCType(longT)):
  325. argname = direct_solve(
  326. NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
  327. )
  328. return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt"
  329. elif goal.type == BaseCType(optionalIntArrayRefT):
  330. try:
  331. return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
  332. except UnsatError:
  333. argname = direct_solve(
  334. NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
  335. )
  336. return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt"
  337. elif goal.type == BaseCType(optionalSymIntArrayRefT):
  338. # TODO: You might also want to solve this from longSymVec_ctype or
  339. # an optional version of it
  340. argname = direct_solve(
  341. NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
  342. )
  343. return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt"
  344. elif goal.type == BaseCType(optionalScalarRefT):
  345. return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
  346. elif goal.type == BaseCType(optionalTensorRefT):
  347. return direct_solve(NamedCType(goal.name, optionalTensor_ctype))
  348. # Note [translation from C++ reference to value types]
  349. # The below cases are all for when we have an argument with a reference type,
  350. # and a corresponding goal with a value type.
  351. # These are needed when we populate the inputs to a lambda capture and we need
  352. # to guarantee the lifetime of each captured argument.
  353. # We guard it with an explicit kwarg because converting to a value type is expensive
  354. # (O(n)) to convert from IntArrayRef to vector<int>),
  355. # so the caller of translate() should be explicit that they need it.
  356. if allow_expensive_conversions:
  357. if goal.type == VectorCType(BaseCType(longT)):
  358. intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
  359. argname = direct_solve(intArrayRef_ctype)
  360. return f"{argname}.vec()"
  361. if goal.type == VectorCType(BaseCType(SymIntT)):
  362. symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
  363. argname = direct_solve(symIntArrayRef_ctype)
  364. return f"{argname}.vec()"
  365. elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
  366. optionalIntArrayRef_ctype = NamedCType(
  367. goal.name, BaseCType(optionalIntArrayRefT)
  368. )
  369. argname = direct_solve(optionalIntArrayRef_ctype)
  370. return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt"
  371. elif goal.type == OptionalCType(BaseCType(scalarT)):
  372. optionalScalarRef_ctype = NamedCType(
  373. goal.name, BaseCType(optionalScalarRefT)
  374. )
  375. argname = direct_solve(optionalScalarRef_ctype)
  376. return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
  377. elif goal.type == OptionalCType(BaseCType(scalarT)):
  378. optionalTensorRef_ctype = NamedCType(
  379. goal.name, BaseCType(optionalTensorRefT)
  380. )
  381. argname = direct_solve(optionalTensorRef_ctype)
  382. return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
  383. # Technically, we also need to handle cases of C++ containers holding reference types.
  384. # But there currently aren't any ops that require lambda capture codegen
  385. # With arguments like ::std::vector<IntArrayRef>.
  386. # If that changes, we'll have to add the translation here.
  387. # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
  388. # We could probably generalize this to non-tensor types too.
  389. if goal.type == MutRefCType(BaseCType(tensorT)):
  390. const_ref_tensor_ctype = NamedCType(
  391. goal.name, ConstRefCType(BaseCType(tensorT))
  392. )
  393. argname = direct_solve(const_ref_tensor_ctype)
  394. return f"const_cast<Tensor&>({argname})"
  395. unsat(goal)
  396. return [Expr(solve(g, direct=False), g) for g in goal_ctypes]