python.py 58 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. from typing import TYPE_CHECKING
  4. from torchgen.api import cpp
  5. from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
  6. from torchgen.gen import pythonify_default
  7. from torchgen.model import (
  8. Argument,
  9. BaseTy,
  10. BaseType,
  11. FunctionSchema,
  12. ListType,
  13. NativeFunction,
  14. OptionalType,
  15. Return,
  16. Type,
  17. Variant,
  18. )
  19. if TYPE_CHECKING:
  20. from collections.abc import Iterable, Sequence
  21. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  22. #
  23. # Data Models
  24. #
  25. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  26. #
  27. # [Notes] python binding codegen
  28. #
  29. # The Python binding codegen produces code that takes the input list of
  30. # PyObjects, finds the matching ATen C++ function using PythonArgParser,
  31. # converts the PyObjects into C++ types and calls the ATen C++ function:
  32. #
  33. # +--------+ parsing +------------------------+ binding +-----------------------+
  34. # | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
  35. # +--------+ +------------------------+ +-----------------------+
  36. #
  37. # The following examples demonstrate the data models the Python binding
  38. # codegen needs to deal with and the tasks it needs to accomplish. It
  39. # helps explain the purpose of the new data types introduced below.
  40. #
  41. # - Function Schema (source of truth)
  42. #
  43. # aten::empty.names(int[] size, *, Dimname[]? names,
  44. # ScalarType? dtype=None, Layout? layout=None,
  45. # Device? device=None, bool? pin_memory=None,
  46. # MemoryFormat? memory_format=None) -> Tensor
  47. #
  48. # - Python Signature
  49. #
  50. # It's used to generate input schema string for PythonArgParser.
  51. # Note: TensorOptions fields are reordered and the additional
  52. # 'requires_grad' field is added:
  53. #
  54. # empty(IntArrayRef size, *, DimnameList? names,
  55. # MemoryFormat? memory_format=None, ScalarType dtype=None,
  56. # Layout layout=torch.strided, Device device=None,
  57. # bool pin_memory=False, bool requires_grad=False)
  58. #
  59. # - C++ Signature
  60. #
  61. # It's used to generate C++ lambda formals & dispatch call.
  62. # Note: the scattered TensorOptions fields are packed into 'options'.
  63. #
  64. # auto dispatch_empty =
  65. # [](IntArrayRef size, std::optional<DimnameList> names,
  66. # const TensorOptions & options,
  67. # std::optional<MemoryFormat> memory_format) -> Tensor {
  68. # pybind11::gil_scoped_release no_gil;
  69. # return torch::empty(size, names, options, memory_format);
  70. # };
  71. #
  72. # - Binding between Python Arguments and C++ Arguments
  73. #
  74. # Given a set of Python Arguments in scope, we need to produce the
  75. # binding expressions that translate the Python API into C++ API:
  76. #
  77. # Python Args Cpp Args Binding Exprs
  78. # -----------------------------------------------------------------
  79. # 0: size size '_r.intlist(0)'
  80. # 1: names names 'names' [special init]
  81. # 2: memory_format -------+
  82. # 3: dtype -----+-|--> options 'options' [special packing]
  83. # 4: layout / |
  84. # 5: device / +--> memory_format '_r.memoryformatOptional(2)'
  85. # 6: pin_memory /
  86. # 7: requires_grad -+
  87. #
  88. # So the full dispatch expression would look like:
  89. #
  90. # dispatch_empty(_r.intlist(0), names, options,
  91. # _r.memoryformatOptional(2))
  92. #
  93. # Where does 'names' come from? It involves special local init:
  94. #
  95. # auto __names = _r.toDimnameListOptional(1);
  96. # std::optional<DimnameList> names =
  97. # __names ? std::make_optional(DimnameList(__names.value()))
  98. # : std::nullopt;
  99. #
  100. # Where does 'options' come from? It involves special local init
  101. # for TensorOptions. Note that Python side has the additional
  102. # 'requires_grad' field:
  103. #
  104. # const auto options = TensorOptions()
  105. # .dtype(_r.scalartype(3))
  106. # .device(_r.device(5))
  107. # .layout(_r.layoutOptional(4))
  108. # .requires_grad(_r.toBool(7))
  109. # .pinned_memory(_r.toBool(6));
  110. #
  111. # In some other cases one Python Argument can map to multiple C++
  112. # Arguments. For example:
  113. #
  114. # aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
  115. # -> (Tensor values, Tensor indices)
  116. #
  117. # Python Args Cpp Args Binding Exprs
  118. # ---------------------------------------------------------------------
  119. # +----> max 'out[0]'
  120. # /-----> max_values 'out[1]'
  121. # 0: input / self '_r.tensor(0)'
  122. # 1: dim / dim '_r.dimname(1)'
  123. # 2: keepdim / keepdim '_r.toBool(2)'
  124. # 3: out -----+ [local init] out '_r.tensorlist_n<2>(3)'
  125. #
  126. # As demonstrated above, the binding can involve reordering,
  127. # packing, unpacking and special local inits.
  128. #
  129. #
  130. # Let's look at a concrete example:
  131. #
  132. # static PythonArgParser parser({
  133. # "abs(Tensor input, *, Tensor out=None)",
  134. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  135. # ^
  136. # +--- Python Schema, represented by PythonSignature and PythonArgument
  137. #
  138. # }, /*traceable=*/true);
  139. #
  140. # ParsedArgs<2> parsed_args;
  141. # auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
  142. #
  143. # ...
  144. #
  145. # if (_r.isNone(1)) {
  146. # ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out')
  147. # represented by PythonArgParserOutputExpr
  148. #
  149. # // aten::abs(Tensor self) -> Tensor
  150. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  151. # ^
  152. # +--- NativeFunction schema, base version
  153. #
  154. # auto dispatch_abs = [](const Tensor & self) -> Tensor {
  155. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  156. # ^
  157. # +--- dispatch_lambda_args / dispatch_lambda_return_str
  158. # generated from NativeFunction / CppSignature
  159. # (deprecated PythonSignature is special)
  160. # arguments are represented by DispatchLambdaArgument
  161. #
  162. # pybind11::gil_scoped_release no_gil;
  163. # return self.abs();
  164. # ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs
  165. # generated from NativeFunction / CppSignature
  166. # };
  167. # return wrap(dispatch_abs(_r.tensor(0)));
  168. # ~~~~~~~~~~~~~
  169. # ^
  170. # +--- dispatch_lambda_exprs
  171. # binding PythonArgParserOutputExpr (python args)
  172. # and DispatchLambdaArgument (c++ args)
  173. #
  174. # } else {
  175. # // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
  176. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  177. # ^
  178. # +--- NativeFunction schema, out-variant
  179. #
  180. # auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
  181. # pybind11::gil_scoped_release no_gil;
  182. # return at::abs_out(out, self);
  183. # };
  184. # return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
  185. # }
  186. #
  187. #
  188. # [Notes] python interface codegen
  189. # The python dataclasses below are used to generate both python binding code
  190. # and pyi type hint signatures.
  191. # In theory these two should look very similar, but there are a number of differences
  192. # in how pyi signatures vs. python_arg_parser signatures are generated.
  193. # These differences have been encapsulated in signature_str() vs. signature_str_pyi()
  194. # to display the full signatures, and argument_str() vs argument_str_pyi() to display arguments.
  195. # For example, only pyi signatures include return types.
  196. def format_function_signature(
  197. name: str, arguments: Iterable[str] = (), return_type: str | None = None
  198. ) -> str:
  199. if not isinstance(arguments, (list, tuple)):
  200. arguments = tuple(arguments)
  201. return_type = f" -> {return_type}" if return_type is not None else ""
  202. sig = f"def {name}({', '.join(arguments)}){return_type}: ..."
  203. if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
  204. return sig
  205. lines = [
  206. f"def {name}(",
  207. *(f" {arg}," for arg in arguments),
  208. f"){return_type}: ...",
  209. ]
  210. sig = "\n".join(lines)
  211. if all(len(line) <= 80 for line in lines):
  212. return sig
  213. # ruff format bug for compound statements: https://github.com/astral-sh/ruff/issues/18658
  214. # use `skip` instead of `on` + `off`
  215. return sig.removesuffix(" ...") + " # fmt: skip\n ..."
  216. @dataclass(frozen=True)
  217. class PythonReturns:
  218. returns: tuple[Return, ...]
  219. @dataclass(frozen=True)
  220. class PythonArgument:
  221. name: str
  222. type: Type
  223. default: str | None
  224. # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
  225. #
  226. # _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
  227. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  228. # ^
  229. # +--- default_init str
  230. default_init: str | None
  231. # Compute argument formal for python argument parsing.
  232. # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
  233. def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
  234. type_str = (
  235. argument_type_str(self.type, symint=symint)
  236. .replace("const ", "")
  237. .replace(" &", "")
  238. )
  239. name = self.name
  240. # s/self/input/ outside method bindings
  241. # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
  242. # for the parse string
  243. if name == "self" and type_str in ["Tensor", "Number"] and not method:
  244. name = "input"
  245. # add default
  246. if self.default is not None:
  247. default = {
  248. "nullptr": "None",
  249. "::std::nullopt": "None",
  250. "std::nullopt": "None",
  251. "{}": "None",
  252. }.get(self.default, self.default)
  253. return f"{type_str} {name}={default}"
  254. else:
  255. return f"{type_str} {name}"
  256. def argument_str_pyi(
  257. self, *, method: bool = False, deprecated: bool = False
  258. ) -> str:
  259. type_str = argument_type_str_pyi(self.type)
  260. name = self.name
  261. # s/self/input/ outside method bindings
  262. # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
  263. # for the parse string
  264. if name == "self" and type_str == "Tensor" and not method and not deprecated:
  265. name = "input"
  266. if name == "from": # from is a Python keyword...
  267. name += "_"
  268. # pyi merges the _out and functional variants into the same signature, with an optional out arg
  269. if name == "out" and type_str == "Tensor" and not deprecated:
  270. type_str = f"{type_str} | None".replace(" | None | None", " | None")
  271. # pyi deprecated signatures don't get defaults for their out arg
  272. treat_as_no_default = (
  273. deprecated
  274. and isinstance(self, PythonOutArgument)
  275. and self.default == "None"
  276. )
  277. # add default
  278. if self.default is not None and not treat_as_no_default:
  279. if (
  280. isinstance(self.type, ListType)
  281. and self.type.elem == BaseType(BaseTy.int)
  282. and self.default.startswith("{")
  283. and self.default.endswith("}")
  284. ):
  285. default = (
  286. "(" + ", ".join(map(str.strip, self.default[1:-1].split(","))) + ")"
  287. )
  288. else:
  289. default = {
  290. "nullptr": "None",
  291. "::std::nullopt": "None",
  292. "std::nullopt": "None",
  293. "{}": "None",
  294. "c10::MemoryFormat::Contiguous": "contiguous_format",
  295. "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
  296. }.get(self.default, self.default)
  297. return f"{name}: {type_str} = {default}"
  298. else:
  299. return f"{name}: {type_str}"
  300. @dataclass(frozen=True)
  301. class PythonOutArgument(PythonArgument):
  302. # In Python signature multiple output fields are packed into one 'out' argument.
  303. # When binding to C++, it's first binded to a local 'out' variable:
  304. # 'auto out = _r.tensorlist_n<2>(2);',
  305. # then binded to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
  306. # TODO: maybe don't need keep scattered out fields for python signature?
  307. outputs: tuple[PythonArgument, ...]
  308. @staticmethod
  309. def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
  310. if not outputs:
  311. return None
  312. size = len(outputs)
  313. if size == 1:
  314. return PythonOutArgument(
  315. name=outputs[0].name,
  316. type=outputs[0].type,
  317. default="None",
  318. default_init=None,
  319. outputs=outputs,
  320. )
  321. elif size > 1:
  322. if any(not a.type.is_tensor_like() for a in outputs):
  323. raise RuntimeError(f"Unsupported output type: {outputs}")
  324. return PythonOutArgument(
  325. name="out",
  326. # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
  327. type=ListType(BaseType(BaseTy.Tensor), size),
  328. default="None",
  329. default_init=None,
  330. outputs=outputs,
  331. )
  332. raise AssertionError(r"Unexpected PythonOutArgument size")
  333. @dataclass(frozen=True)
  334. class PythonSignature:
  335. # Base operator name, without inplace/outplace suffix.
  336. name: str
  337. # Positional arguments.
  338. # TODO: create a dedicated SelfArgument type for 'self'?
  339. input_args: tuple[PythonArgument, ...]
  340. # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
  341. # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
  342. input_kwargs: tuple[PythonArgument, ...]
  343. output_args: PythonOutArgument | None
  344. # Return types, which are only used by pyi
  345. returns: PythonReturns
  346. # These are scattered kwargs arguments belonging to TensorOptions.
  347. # When binding to C++, they are packed into a TensorOptions object 'options'.
  348. # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
  349. # for out variant), in which case they will be used as scattered fields without
  350. # being packed into 'options'.
  351. # TODO: maybe create a PythonTensorOptionsArgument?
  352. tensor_options_args: tuple[PythonArgument, ...]
  353. # method or function signature?
  354. method: bool
  355. @property
  356. def deprecated(self) -> bool:
  357. return False
  358. def arguments(
  359. self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
  360. ) -> tuple[PythonArgument | PythonOutArgument, ...]:
  361. result: list[PythonArgument | PythonOutArgument] = []
  362. result.extend(self.input_args)
  363. result.extend(self.input_kwargs)
  364. if self.output_args is not None and not skip_outputs:
  365. result.append(self.output_args)
  366. if not skip_tensor_options:
  367. result.extend(self.tensor_options_args)
  368. return tuple(result)
  369. def arguments_count(self) -> int:
  370. return len(self.arguments())
  371. def output_idx(self) -> int:
  372. return len(self.input_args) + len(self.input_kwargs)
  373. # [old codegen] Compute the Python function signature for argument parsing,
  374. # as specified in torch/csrc/utils/python_arg_parser.h. WARNING:
  375. # this is NOT the same type signature as specified by PEP 484
  376. # as understood by mypy; our format was independently developed
  377. # and has some quirks to make it more suitable specifically
  378. # for error parsing.
  379. #
  380. # For a translation to mypy-valid type signatures, see
  381. # signature_str_pyi().
  382. def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
  383. args = self.arguments(skip_outputs=skip_outputs)
  384. schema_formals: list[str] = [
  385. a.argument_str(method=self.method, symint=symint) for a in args
  386. ]
  387. positional_argc = len(self.input_args)
  388. if len(schema_formals) > positional_argc:
  389. schema_formals.insert(positional_argc, "*")
  390. return f"{self.name}({', '.join(schema_formals)})"
  391. def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
  392. args = self.arguments(skip_outputs=skip_outputs)
  393. schema_formals: list[str] = [
  394. a.argument_str_pyi(method=self.method) for a in args
  395. ]
  396. positional_argc = len(self.input_args)
  397. if len(schema_formals) > positional_argc:
  398. schema_formals.insert(positional_argc, "*")
  399. # only pyi signatures include returns
  400. returns_str = returns_str_pyi(self)
  401. # pyi also includes self (with no typing/defaults) for methods
  402. if self.method:
  403. schema_formals.insert(0, "self")
  404. return format_function_signature(self.name, schema_formals, returns_str)
  405. def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
  406. # only pyi uses vararg signatures
  407. args = self.arguments(skip_outputs=skip_outputs)
  408. schema_formals: list[str] = [
  409. a.argument_str_pyi(method=self.method) for a in args
  410. ]
  411. # vararg only applies to pyi signatures. vararg variants are not generated for all signatures
  412. num_args = self.arguments_count()
  413. if num_args == 0:
  414. return None
  415. num_positionalargs = len(self.input_args)
  416. vararg_type = args[0].type
  417. if not (
  418. isinstance(vararg_type, ListType)
  419. and str(vararg_type.elem) in ["int", "SymInt"]
  420. and num_positionalargs == 1
  421. ):
  422. return None
  423. # Below are the major changes in vararg vs. regular pyi signatures
  424. # vararg signatures also omit the asterix
  425. if not isinstance(vararg_type, ListType):
  426. raise AssertionError(f"Expected ListType, got {type(vararg_type)}")
  427. schema_formals[0] = (
  428. "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
  429. )
  430. returns_str = returns_str_pyi(self)
  431. # pyi also includes self (with no typing/defaults) for methods
  432. if self.method:
  433. schema_formals.insert(0, "self")
  434. return format_function_signature(self.name, schema_formals, returns_str)
  435. # The deprecated python signature involves some special logic, so create a
  436. # dedicated data model to store these extra properties.
  437. @dataclass(frozen=True)
  438. class PythonSignatureDeprecated(PythonSignature):
  439. # Schema for the deprecated function
  440. deprecated_schema: FunctionSchema
  441. # The deprecated signature might miss some arguments that the corresponding
  442. # C++ signature expects. We need store the constant default values to pass in.
  443. # For example:
  444. # [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
  445. # [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  446. # [func call]: self.addmm(mat1, mat2, beta, 1)
  447. # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
  448. deprecated_args_exprs: tuple[str, ...]
  449. @property
  450. def deprecated(self) -> bool:
  451. return True
  452. def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
  453. return (
  454. PythonSignature.signature_str(
  455. self, skip_outputs=skip_outputs, symint=symint
  456. )
  457. + "|deprecated"
  458. )
  459. def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
  460. args = self.arguments(skip_outputs=skip_outputs)
  461. schema_formals: list[str] = [
  462. a.argument_str_pyi(method=self.method, deprecated=True) for a in args
  463. ]
  464. positional_argc = len(self.input_args)
  465. if len(schema_formals) > positional_argc:
  466. schema_formals.insert(positional_argc, "*")
  467. returns_str = returns_str_pyi(self)
  468. return format_function_signature(self.name, schema_formals, returns_str)
  469. def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
  470. # the codegen doesn't include vararg variants for deprecated signatures
  471. return None
  472. # This struct is used to hold the PythonSignature and its corresponding
  473. # NativeFunction BEFORE grouping base and out-variant functions.
  474. # Why not store NativeFunction in PythonSignature or construct PythonSignature
  475. # from NativeFunction? Because they are not 1-1 mapped.
  476. # One native function could have both deprecated and non-deprecated python
  477. # signatures - NativeFunction doesn't contain information to construct the
  478. # deprecated python signature.
  479. # One python signature is used to handle both the base and the out-variant
  480. # function - see 'PythonSignatureGroup'.
  481. @dataclass(frozen=True)
  482. class PythonSignatureNativeFunctionPair:
  483. signature: PythonSignature
  484. function: NativeFunction
  485. # We merge pairs of functions with signatures that are equivalent mod
  486. # output arguments, and use a single entry in the python_arg_parser sig
  487. # list for both (output arguments become optional).
  488. @dataclass(frozen=True)
  489. class PythonSignatureGroup:
  490. # The signature used for Python argument parsing. The outplace signature
  491. # is preferred if exists, because it can be used to parse inputs for both
  492. # the out-place variant and the base version (with output omitted).
  493. signature: PythonSignature
  494. # The regular ATen declaration (e.g. conv2d)
  495. base: NativeFunction
  496. # The out variant (e.g. conv2d_out)
  497. outplace: NativeFunction | None
  498. @classmethod
  499. def from_pairs(
  500. cls,
  501. functional: PythonSignatureNativeFunctionPair,
  502. out: PythonSignatureNativeFunctionPair | None,
  503. ) -> PythonSignatureGroup:
  504. if out is None:
  505. return PythonSignatureGroup(
  506. signature=functional.signature,
  507. base=functional.function,
  508. outplace=None,
  509. )
  510. # prefer the signature with optional out=... arguments because it's the
  511. # superset that can be used to parse input for both base and outplace.
  512. signature_kwargs = out.signature.__dict__.copy()
  513. # Out overloads in C++ don't have TensorOptions arguments,
  514. # so take these from the functional variant
  515. signature_kwargs["tensor_options_args"] = (
  516. functional.signature.tensor_options_args
  517. )
  518. return PythonSignatureGroup(
  519. signature=type(out.signature)(**signature_kwargs),
  520. base=functional.function,
  521. outplace=out.function,
  522. )
  523. # C++ function dispatch is wrapped in a lambda function. The lambda function
  524. # has almost the same signature as the C++ function, only with some small
  525. # variants - see details below.
  526. # This data model is used to represent arguments of the lambda function
  527. # signature.
  528. @dataclass(frozen=True)
  529. class DispatchLambdaArgument:
  530. name: str
  531. type_str: str
  532. is_out_arg: bool
  533. # To pass PyObjects arguments to C++ function (via the lambda wrapper),
  534. # we need first convert PyObjects into simple C++ objects. This work
  535. # is done by PythonArgParser.
  536. # This data model is used to represent the output of PythonArgParser.
  537. # It has 1-1 mapping with PythonArgument in PythonSignature.
  538. @dataclass(frozen=True)
  539. class PythonArgParserOutputExpr:
  540. # argument name
  541. name: str
  542. # RHS expression to reference PythonArgParser output.
  543. expr: str
  544. # In some special cases we need create different expr, e.g.:
  545. # '_r.isNone(1)' instead of '_r.tensor(1)'.
  546. index: int
  547. # The python argument it maps to.
  548. argument: PythonArgument
  549. @property
  550. def is_none_expr(self) -> str:
  551. return f"_r.isNone({self.index})"
  552. # To pass PythonArgParser output to the lambda wrapper, we need bind
  553. # PythonArgParserOutputExpr to DispatchLambdaArgument.
  554. # They are not always 1-1 mapped, e.g. scattered TensorOptions fields
  555. # need be packed into a TensorOptions object, which is the argument
  556. # that the lambda function wrapper takes.
  557. @dataclass(frozen=True)
  558. class DispatchLambdaArgumentExprs:
  559. # The exprs that provide the binding for lambda arguments, e.g.:
  560. #
  561. # 'self' -> '_r.tensor(0)'
  562. # 'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
  563. # 'options' -> 'options'
  564. #
  565. # It has 1-1 mapping with DispatchLambdaArgument.
  566. exprs: Sequence[str]
  567. # Special local inits, which might introduce new variables that
  568. # the 'exprs' above reference, e.g.:
  569. #
  570. # 'auto out = _r.tensorlist_n<2>(2);'
  571. #
  572. inits: Sequence[str]
  573. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  574. #
  575. # Helper Functions
  576. #
  577. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  578. def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
  579. return CppSignatureGroup.from_native_function(f, method=method).signature
  580. def has_tensor_options(f: NativeFunction) -> bool:
  581. return f.func.arguments.tensor_options is not None
  582. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  583. #
  584. # Python Signature
  585. #
  586. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  587. # 'simple_type' was introduced by the old codegen, which is slightly
  588. # different from the python schema type, e.g.: doesn't have '?' suffix
  589. # for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
  590. def argument_type_str(
  591. t: Type, *, simple_type: bool = False, symint: bool = True
  592. ) -> str:
  593. if isinstance(t, BaseType):
  594. if t.name == BaseTy.int:
  595. return "int64_t"
  596. elif t.name == BaseTy.float:
  597. return "double"
  598. elif t.name == BaseTy.str:
  599. return "c10::string_view"
  600. elif t.name in [
  601. BaseTy.Tensor,
  602. BaseTy.bool,
  603. BaseTy.QScheme,
  604. BaseTy.Scalar,
  605. BaseTy.ScalarType,
  606. BaseTy.Generator,
  607. BaseTy.Storage,
  608. BaseTy.Layout,
  609. BaseTy.Device,
  610. BaseTy.DeviceIndex,
  611. BaseTy.MemoryFormat,
  612. BaseTy.Dimname,
  613. BaseTy.Stream,
  614. BaseTy.SymInt,
  615. ]:
  616. # These python schema type names line up with their function schema names
  617. return t.name.name
  618. elif isinstance(t, OptionalType):
  619. elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
  620. return f"{elem}?"
  621. elif isinstance(t, ListType):
  622. size = t.size if not simple_type else None
  623. if str(t.elem) == "bool":
  624. if t.size is None:
  625. raise AssertionError("bool ListType must have a size")
  626. return f"::std::array<bool,{t.size}>"
  627. elif str(t.elem) == "int":
  628. return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
  629. elif str(t.elem) == "SymInt":
  630. if symint:
  631. return (
  632. f"SymIntArrayRef[{size}]" if size is not None else "SymIntArrayRef"
  633. )
  634. else:
  635. return f"IntArrayRef[{size}]" if size is not None else "IntArrayRef"
  636. elif str(t.elem) == "Tensor":
  637. return f"TensorList[{size}]" if size is not None else "TensorList"
  638. elif str(t.elem) == "Scalar":
  639. return f"ScalarList[{size}]" if size is not None else "ScalarList"
  640. elif str(t.elem) == "Tensor?":
  641. if simple_type:
  642. return "c10::List<::std::optional<Tensor>>"
  643. else:
  644. return "const c10::List<::std::optional<Tensor>> &"
  645. elif str(t.elem) == "Dimname":
  646. return f"DimnameList[{size}]" if size is not None else "DimnameList"
  647. elem = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
  648. return f"ArrayRef<{elem}>"
  649. raise RuntimeError(f"unrecognized type {repr(t)}")
  650. def argument_type_size(t: Type) -> int | None:
  651. l = t.is_list_like()
  652. if l is not None and str(l.elem) != "bool":
  653. return l.size
  654. else:
  655. return None
  656. def argument(a: Argument) -> PythonArgument:
  657. return PythonArgument(
  658. name=a.name,
  659. type=a.type,
  660. # TODO: directly translate a.default to python default
  661. default=(
  662. str(pythonify_default(cpp.default_expr(a.default, a.type, symint=False)))
  663. if a.default is not None
  664. else None
  665. ),
  666. default_init=None,
  667. )
  668. # Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
  669. def signature(
  670. f: NativeFunction, *, method: bool = False, pyi: bool = False
  671. ) -> PythonSignature:
  672. return signature_from_schema(
  673. f.func, category_override=f.category_override, method=method, pyi=pyi
  674. )
def signature_from_schema(
    func: FunctionSchema,
    *,
    category_override: str | None,
    method: bool = False,
    pyi: bool = False,
) -> PythonSignature:
    """Build a PythonSignature from a FunctionSchema.

    Positional, keyword-only and out arguments are taken from ``func``:
    ``self`` is omitted when ``method`` is True, and the schema's own
    TensorOptions group is skipped. For factory functions and new_*/*_like
    functions (unless ``category_override`` is "dummy"), the scattered Python
    TensorOptions arguments dtype/layout/device/pin_memory/requires_grad are
    appended instead.

    Args:
        func: the schema to translate.
        category_override: the native function's category override; the
            values "factory", "new", "like" and "dummy" are inspected here.
        method: build the signature of the Tensor method variant.
        pyi: not read by this function body.

    Raises:
        ValueError: if the schema itself declares an argument named
            ``requires_grad``, which is reserved for the Python-side
            TensorOptions arguments.
    """
    # Collect schema arguments in Python binding order.
    args: list[Argument] = []
    args.extend(func.arguments.pre_self_positional)
    # Skip SelfArgument if this is method.
    if not method and func.arguments.self_arg is not None:
        args.append(func.arguments.self_arg.argument)
    args.extend(func.arguments.post_self_positional)
    args.extend(func.arguments.pre_tensor_options_kwarg_only)
    # Skip TensorOptionsArguments. Python side TensorOptions
    # arguments are created based on different rules - see below.
    args.extend(func.arguments.post_tensor_options_kwarg_only)
    args.extend(func.arguments.out)

    # Split ``args`` by name into positional, keyword-only and out groups,
    # keeping the order of ``args`` within each group.
    input_arg_set = {a.name for a in func.arguments.flat_positional}
    kwarg_only_set = {a.name for a in func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in func.arguments.out}

    input_args = tuple(map(argument, filter(lambda a: a.name in input_arg_set, args)))
    input_kwargs = tuple(
        map(argument, filter(lambda a: a.name in kwarg_only_set, args))
    )
    outputs = tuple(map(argument, filter(lambda a: a.name in out_arg_set, args)))

    # Reintroduce the scattered fields of TensorOptions for Python.
    # Compared to the cpp counterpart, the python arguments have new property
    # (default_init) and a new argument 'requires_grad', which require some
    # special handlings.
    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
    # to the original versions in the yaml, this recreation is a potential
    # source of drift between eager and JIT. Pull this logic out to a shared place.

    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in func.arguments.flat_non_out
    )
    if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
        raise ValueError(
            "argument named requires_grad is reserved, should not explicitly add it in the schema"
        )

    # [old codegen] this probably won't work if one of the returns is not a tensor,
    # but it will produce a compile-time error that is obvious.
    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)

    name: str = cpp.name(func)
    # Factory: explicitly overridden as one, or returns a tensor without
    # taking any tensor-like non-out input.
    is_factory_function = category_override == "factory" or (
        has_tensor_return and not has_tensor_input_arg
    )
    is_like_or_new_function = (
        category_override in ("new", "like")
        or name.startswith("new_")
        or name.endswith("_like")
    )
    is_dummy_function = category_override == "dummy"

    tensor_options_args: list[PythonArgument] = []
    if (is_factory_function or is_like_or_new_function) and not is_dummy_function:

        def topt_default_init(name: str) -> str | None:
            # C++ default expression of the schema's TensorOptions field
            # ``name``; None when the schema has no TensorOptions group or the
            # field's default is absent or "None".
            topt_args = func.arguments.tensor_options
            if topt_args is None:
                return None
            a = getattr(topt_args, name)
            if a.default is None or a.default == "None":
                return None
            return cpp.default_expr(a.default, a.type, symint=False)

        # new_*/*_like functions never get a default_init for dtype/layout/device.
        tensor_options_args.append(
            PythonArgument(
                name="dtype",
                type=OptionalType(BaseType(BaseTy.ScalarType)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("dtype")
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="layout",
                type=OptionalType(BaseType(BaseTy.Layout)),
                default="None",
                default_init=(
                    None if is_like_or_new_function else topt_default_init("layout")
                ),
            )
        )
        # For factories without a schema device default, fall back to the
        # default-device expression.
        tensor_options_args.append(
            PythonArgument(
                name="device",
                type=OptionalType(BaseType(BaseTy.Device)),
                default="None",
                default_init=(
                    None
                    if is_like_or_new_function
                    else (
                        topt_default_init("device")
                        or "torch::tensors::get_default_device()"
                    )
                ),
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="pin_memory",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            )
        )
        tensor_options_args.append(
            PythonArgument(
                name="requires_grad",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            )
        )

    returns = PythonReturns(returns=func.returns)

    return PythonSignature(
        name=str(func.name.name),
        input_args=input_args,
        input_kwargs=input_kwargs,
        output_args=PythonOutArgument.from_outputs(outputs),
        tensor_options_args=tuple(tensor_options_args),
        returns=returns,
        method=method,
    )
  799. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  800. #
  801. # Python Interface
  802. #
  803. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  804. def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
  805. if len(returns) <= 1 or all(r.name is None for r in returns):
  806. return []
  807. else:
  808. if any(r.name is None for r in returns):
  809. # When building on Windows, `PyStructSequence_UnnamedField` could not be
  810. # resolved by the linker for some reason, which cause error in building:
  811. #
  812. # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
  813. # PyStructSequence_UnnamedField
  814. #
  815. # Thus, at this point in time, we do not support unnamed
  816. # fields in structseq; you must either name all fields,
  817. # or none of them.
  818. raise ValueError("Unnamed field is not supported by codegen")
  819. return [str(r.name) for r in returns]
  820. def argument_type_str_pyi(t: Type) -> str:
  821. add_optional = False
  822. if isinstance(t, OptionalType):
  823. t = t.elem
  824. add_optional = True
  825. ret = ""
  826. if isinstance(t, BaseType):
  827. if t.name in [BaseTy.int, BaseTy.DeviceIndex]:
  828. ret = "_int"
  829. if t.name == BaseTy.SymInt:
  830. ret = "_int | SymInt"
  831. elif t.name == BaseTy.float:
  832. ret = "_float"
  833. elif t.name == BaseTy.str:
  834. ret = "str"
  835. elif t.name == BaseTy.Scalar:
  836. ret = "Number | _complex"
  837. elif t.name == BaseTy.ScalarType:
  838. ret = "_dtype"
  839. elif t.name == BaseTy.bool:
  840. ret = "_bool"
  841. elif t.name == BaseTy.QScheme:
  842. ret = "_qscheme"
  843. elif t.name == BaseTy.Layout:
  844. ret = "_layout"
  845. elif t.name == BaseTy.Device:
  846. ret = "DeviceLikeType | None"
  847. elif t.name == BaseTy.MemoryFormat:
  848. ret = "memory_format"
  849. elif t.name == BaseTy.Dimname:
  850. ret = "str | EllipsisType | None"
  851. elif t.name == BaseTy.Storage:
  852. ret = "Storage | UntypedStorage"
  853. elif t.name in [BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream]:
  854. # These python schema type names line up with their function schema names
  855. ret = t.name.name
  856. elif isinstance(t, ListType):
  857. if str(t.elem) == "int":
  858. ret = "_int | _size" if t.size is not None else "_size"
  859. elif t.is_tensor_like():
  860. # TODO: this doesn't seem right...
  861. # Tensor?[] currently translates to tuple[Tensor, ...] | list[Tensor] | None
  862. # It should probably translate to tuple[Tensor | None, ...] | list[Tensor | None]
  863. add_optional = True
  864. ret = (
  865. "Tensor | tuple[Tensor, ...] | list[Tensor]"
  866. if t.size is not None
  867. else "tuple[Tensor, ...] | list[Tensor]"
  868. )
  869. elif str(t.elem) == "float":
  870. ret = "Sequence[_float]"
  871. elif str(t.elem) == "SymInt" and t.size is not None:
  872. elem = argument_type_str_pyi(t.elem)
  873. ret = f"{elem} | Sequence[{elem}]"
  874. else:
  875. elem = argument_type_str_pyi(t.elem)
  876. ret = f"Sequence[{elem}]"
  877. else:
  878. raise RuntimeError(f"unrecognized type {repr(t)}")
  879. if add_optional:
  880. ret = f"{ret} | None".replace(" | None | None", " | None")
  881. return ret
  882. def return_type_str_pyi(t: Type) -> str:
  883. # Where arguments are open to accepting Union, return types should return
  884. # concrete types
  885. if isinstance(t, OptionalType):
  886. inner = return_type_str_pyi(t.elem)
  887. return f"{inner} | None".replace(" | None | None", " | None")
  888. if isinstance(t, BaseType):
  889. if t.name == BaseTy.Device:
  890. return "_device"
  891. elif t.name == BaseTy.Dimname:
  892. return "str | None"
  893. else:
  894. return argument_type_str_pyi(t)
  895. if isinstance(t, ListType):
  896. inner = return_type_str_pyi(t.elem)
  897. return f"tuple[{inner}, ...]"
  898. return argument_type_str_pyi(t)
  899. def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
  900. python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
  901. structseq_name = signature.name
  902. field_names = structseq_fieldnames(signature.returns.returns)
  903. if field_names:
  904. # These types are structseq objects which act like named NamedTuples, but
  905. # the constructor acts like the constructor of tuple. Using typing.NamedTuple
  906. # does not allow us to override __init__.
  907. seq_type = f"tuple[{', '.join(python_returns)}]"
  908. structseq_def_lines = [
  909. f"class {structseq_name}({seq_type}): # fmt: skip",
  910. ]
  911. for name, ret_type in zip(field_names, python_returns):
  912. structseq_def_lines.extend(
  913. [
  914. " @property",
  915. f" def {name}(self) -> {ret_type}: ...",
  916. ]
  917. )
  918. structseq_def_lines.extend(
  919. [
  920. " def __new__(",
  921. " cls,",
  922. f" sequence: {seq_type},",
  923. " ) -> Self: # fmt: skip",
  924. " ...",
  925. f" n_fields: Final[_int] = {len(field_names)}",
  926. f" n_sequence_fields: Final[_int] = {len(field_names)}",
  927. " n_unnamed_fields: Final[_int] = 0",
  928. " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
  929. "", # add an extra newline
  930. ]
  931. )
  932. structseq_def = "\n".join(structseq_def_lines)
  933. # Example:
  934. # structseq_def = (
  935. # "class max(tuple[Tensor, Tensor]): # fmt: skip\n"
  936. # " @property\n"
  937. # " def values(self) -> Tensor: ...\n"
  938. # " @property\n"
  939. # " def indices(self) -> Tensor: ...\n"
  940. # " def __new__(\n"
  941. # " cls,\n"
  942. # " sequence: tuple[Tensor, Tensor],\n"
  943. # " ) -> Self: # fmt: skip\n"
  944. # " ...\n"
  945. # " n_fields: Final[_int] = 2",
  946. # " n_sequence_fields: Final[_int] = 2",
  947. # " n_unnamed_fields: Final[_int] = 0",
  948. # " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
  949. # )
  950. return structseq_name, structseq_def
  951. return None
  952. def returns_str_pyi(signature: PythonSignature) -> str:
  953. field_names = structseq_fieldnames(signature.returns.returns)
  954. if field_names:
  955. return f"torch.return_types.{signature.name}"
  956. python_returns = [return_type_str_pyi(r.type) for r in signature.returns.returns]
  957. if len(python_returns) > 1:
  958. return "tuple[" + ", ".join(python_returns) + "]"
  959. if len(python_returns) == 1:
  960. return python_returns[0]
  961. return "None"
  962. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  963. #
  964. # C++ Function Dispatch
  965. #
  966. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  967. # This section provides APIs to generate the code that does C++ function
  968. # dispatch. The C++ function call is wrapped by a lambda function.
  969. # For example:
  970. #
  971. # // aten::selu_(Tensor(a!) self) -> Tensor(a!)
  972. # auto dispatch_selu_ = [](Tensor self) -> Tensor {
  973. # pybind11::gil_scoped_release no_gil;
  974. # return at::selu_(self);
  975. # };
  976. #
  977. # The lambda function's signature follows the C++ signature in common
  978. # cases, e.g.:
  979. #
  980. # // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
  981. # [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
  982. #
# For out variants the 'out' argument's type is changed from 'Tensor &'
# to 'Tensor'. This is because the lambda is called with the
# PythonArgParser output '_r.tensor(3)', which is a temporary object
# and must be passed by value. Also see comments in 'dispatch_lambda_return_str()'.
  987. #
  988. # // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  989. # [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
  990. #
  991. # For multi-output case it can keep using reference type because the
  992. # PythonArgParser output has been unpacked to local variables, e.g.:
  993. #
  994. # // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
  995. # // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
  996. # [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
  997. #
  998. # For deprecated python signature, it should follow deprecated python arg order.
  999. # TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary?
  1000. def dispatch_lambda_args(
  1001. ps: PythonSignature, f: NativeFunction, symint: bool = True
  1002. ) -> tuple[DispatchLambdaArgument, ...]:
  1003. if isinstance(ps, PythonSignatureDeprecated):
  1004. schema = ps.deprecated_schema
  1005. else:
  1006. schema = f.func
  1007. # Start with cpp arguments - dispatch lambda signature always include 'self'
  1008. cpp_args = cpp.arguments(
  1009. arguments=schema.arguments,
  1010. faithful=False,
  1011. symint=symint,
  1012. method=False,
  1013. cpp_no_default_args=f.cpp_no_default_args,
  1014. )
  1015. out_args: set[str] = {a.name for a in schema.arguments.out}
  1016. # Convert from cpp argument to lambda argument
  1017. def dispatch_lambda_arg(cpp_arg: Binding) -> DispatchLambdaArgument:
  1018. type_str = cpp_arg.type
  1019. is_out_arg = cpp_arg.name in out_args
  1020. if ps.method and cpp_arg.name == "self":
  1021. # For method's 'self', we can use 'const Tensor &' and simply ignore mutability!
  1022. type_str = "const at::Tensor &"
  1023. else:
  1024. # For other cases we need prevent dangling refs to temps (unless it's
  1025. # unpacked scattered output)
  1026. # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
  1027. # TODO: avoid this special handling?
  1028. ensure_temp_safe = len(out_args) <= 1 or not is_out_arg
  1029. if ensure_temp_safe:
  1030. type_str = {
  1031. "at::Tensor &": "at::Tensor",
  1032. }.get(type_str, type_str)
  1033. return DispatchLambdaArgument(
  1034. name=cpp_arg.name,
  1035. type_str=type_str,
  1036. is_out_arg=is_out_arg,
  1037. )
  1038. return tuple(map(dispatch_lambda_arg, cpp_args))
  1039. # [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
  1040. # it's enough to just extend the list here. Before you do this, make sure
  1041. # to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
  1042. SUPPORTED_RETURN_TYPES = {
  1043. "at::Tensor",
  1044. "::std::tuple<at::Tensor,at::Tensor>",
  1045. "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
  1046. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
  1047. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
  1048. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
  1049. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
  1050. "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
  1051. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
  1052. "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
  1053. "::std::tuple<double,int64_t>",
  1054. "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
  1055. "::std::vector<at::Tensor>",
  1056. # Needed for flash attention forw/backward
  1057. "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
  1058. "at::Scalar",
  1059. "bool",
  1060. "int64_t",
  1061. "void*",
  1062. "void",
  1063. "at::QScheme",
  1064. "double",
  1065. "at::IntArrayRef",
  1066. "at::ScalarType",
  1067. "at::Stream",
  1068. }
  1069. def dispatch_lambda_return_str(f: NativeFunction) -> str:
  1070. # [old codegen] Remove type annotation (e.g. 'Tensor' rather than 'Tensor &')
  1071. # because the dispatch lambdas take mutable arguments *by value*, not
  1072. # by reference. If you then return a reference to such an argument, you
  1073. # will now have a pointer to a dangling stack entry. Not good.
  1074. #
  1075. # You want:
  1076. #
  1077. # auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
  1078. # ^^^^^^
  1079. #
  1080. # *not*
  1081. #
  1082. # auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
  1083. # ^^^^^^^
  1084. #
  1085. # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
  1086. # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
  1087. # mutable reference to temporary. Maybe we could assign it to a
  1088. # variable itself.)
  1089. returns_without_annotation = tuple(
  1090. Return(r.name, r.type, None) for r in f.func.returns
  1091. )
  1092. return_str = cpp.returns_type(returns_without_annotation, symint=True).cpp_type()
  1093. if return_str not in SUPPORTED_RETURN_TYPES:
  1094. raise RuntimeError(f"{f.func.name} returns unsupported type {return_str}")
  1095. return return_str
  1096. def cpp_dispatch_target(f: NativeFunction) -> str:
  1097. symint = f.func.has_symint()
  1098. name = cpp.name(f.func, symint_overload=symint)
  1099. if Variant.method in f.variants:
  1100. return f"self.{name}"
  1101. if Variant.function in f.variants:
  1102. if has_tensor_options(f) or f.func.name.name.base.endswith("_like"):
  1103. namespace = "torch"
  1104. else:
  1105. namespace = "at"
  1106. return f"{namespace}::{name}"
  1107. raise RuntimeError(f"could not dispatch, neither function nor method: {f.func}")
  1108. def cpp_dispatch_exprs(
  1109. f: NativeFunction,
  1110. *,
  1111. python_signature: PythonSignature | None = None,
  1112. ) -> tuple[str, ...]:
  1113. cpp_args: Sequence[Binding] = _cpp_signature(f, method=False).arguments()
  1114. exprs: tuple[str, ...] = ()
  1115. if not isinstance(python_signature, PythonSignatureDeprecated):
  1116. # By default the exprs are consistent with the C++ signature.
  1117. exprs = tuple(a.name for a in cpp_args)
  1118. else:
  1119. # For deprecated python signature we may need fill in some constants.
  1120. exprs = tuple(
  1121. filter(
  1122. lambda n: n != "out" or f.func.is_out_fn(),
  1123. python_signature.deprecated_args_exprs,
  1124. )
  1125. )
  1126. if Variant.method in f.variants:
  1127. exprs = tuple(filter("self".__ne__, exprs))
  1128. return exprs
  1129. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  1130. #
  1131. # Python / C++ Args Binding
  1132. #
  1133. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  1134. # We explicitly enumerate the PythonArgParser unpacking methods for all
  1135. # supported types. This might be more verbose than necessary, partially
  1136. # because of the irregularity of unpacking method naming, partially
  1137. # because we want to mimic the old codegen behavior - to reject
  1138. # unexpected and/or unsupported cases which the old codegen rejects.
  1139. # For certain cases it is intentionally more restrictive than necessary,
  1140. # e.g.: it doesn't accepts doublelist with definite size.
  1141. def arg_parser_unpack_method(
  1142. t: Type, default: str | None, default_init: str | None, *, symint: bool = True
  1143. ) -> str:
  1144. has_default_init = default_init is not None
  1145. if has_default_init and str(t) not in (
  1146. "ScalarType?",
  1147. "ScalarType",
  1148. "Device",
  1149. "Device?",
  1150. "Layout",
  1151. "Layout?",
  1152. "bool",
  1153. "bool?",
  1154. ):
  1155. raise RuntimeError(f"type '{t}' does not supported unpacking with default")
  1156. if isinstance(t, BaseType):
  1157. if t.name in [
  1158. BaseTy.Tensor,
  1159. BaseTy.Stream,
  1160. BaseTy.Storage,
  1161. BaseTy.Scalar,
  1162. BaseTy.Dimname,
  1163. ]:
  1164. # These unpack methods line up with their schema names
  1165. return t.name.name.lower()
  1166. elif t.name == BaseTy.ScalarType:
  1167. return "scalartypeWithDefault" if has_default_init else "scalartype"
  1168. elif t.name == BaseTy.Device:
  1169. return "deviceWithDefault" if has_default_init else "device"
  1170. elif t.name == BaseTy.DeviceIndex:
  1171. return "toInt64"
  1172. elif t.name == BaseTy.int:
  1173. return "toInt64"
  1174. elif t.name == BaseTy.SymInt:
  1175. return "toSymInt" if symint else "toInt64"
  1176. elif t.name == BaseTy.bool:
  1177. return "toBoolWithDefault" if has_default_init else "toBool"
  1178. elif t.name == BaseTy.float:
  1179. return "toDouble"
  1180. elif t.name == BaseTy.str:
  1181. return "stringView"
  1182. elif t.name == BaseTy.Layout:
  1183. return "layoutWithDefault" if has_default_init else "layout"
  1184. elif t.name == BaseTy.MemoryFormat:
  1185. return "memoryformat"
  1186. elif isinstance(t, OptionalType):
  1187. if str(t.elem) == "Tensor":
  1188. return "optionalTensor"
  1189. elif str(t.elem) == "Generator":
  1190. return "generator"
  1191. elif str(t.elem) == "Dimname[]":
  1192. return "toDimnameListOptional"
  1193. elif not has_default_init and default in (
  1194. None,
  1195. "None",
  1196. "::std::nullopt",
  1197. "std::nullopt",
  1198. ):
  1199. # If default is None: append 'Optional' to elem's unpacking method
  1200. return (
  1201. arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
  1202. )
  1203. else:
  1204. # Otherwise, load as underlying type with default
  1205. return arg_parser_unpack_method(
  1206. t.elem, default, default_init, symint=symint
  1207. )
  1208. elif isinstance(t, ListType):
  1209. if str(t.elem) == "Tensor":
  1210. # accept and use definite size
  1211. return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"
  1212. elif str(t.elem) == "Tensor?":
  1213. return "list_of_optional_tensors"
  1214. elif str(t.elem) == "Dimname":
  1215. # accept definite size
  1216. return "dimnamelist"
  1217. elif str(t.elem) == "int":
  1218. # accept definite size
  1219. return "intlist"
  1220. elif str(t.elem) == "float":
  1221. return "doublelist"
  1222. elif str(t.elem) == "SymInt":
  1223. # accept definite size
  1224. return "symintlist" if symint else "intlist"
  1225. elif str(t.elem) == "Scalar":
  1226. return "scalarlist"
  1227. raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
  1228. # Return RHS expression for python argument using PythonArgParser output.
  1229. # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
  1230. def arg_parser_output_expr(
  1231. arg_index: int, a: PythonArgument, *, symint: bool = True
  1232. ) -> PythonArgParserOutputExpr:
  1233. has_default = a.default_init is not None
  1234. unpack_method = arg_parser_unpack_method(
  1235. t=a.type, default=a.default, default_init=a.default_init, symint=symint
  1236. )
  1237. default = f", {a.default_init}" if has_default else ""
  1238. expr = f"_r.{unpack_method}({arg_index}{default})"
  1239. return PythonArgParserOutputExpr(
  1240. name=a.name,
  1241. expr=expr,
  1242. index=arg_index,
  1243. argument=a,
  1244. )
  1245. # Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
  1246. def arg_parser_output_exprs(
  1247. ps: PythonSignature, f: NativeFunction, *, symint: bool = True
  1248. ) -> dict[str, PythonArgParserOutputExpr]:
  1249. return {
  1250. e.name: e
  1251. for i, a in enumerate(ps.arguments())
  1252. for e in (arg_parser_output_expr(i, a, symint=symint),)
  1253. }
  1254. # argument name to type for scattered tensor options fields
  1255. TENSOR_OPTIONS_FIELDS = {
  1256. "dtype": "ScalarType?",
  1257. "device": "Device?",
  1258. "layout": "Layout?",
  1259. "pin_memory": "bool?",
  1260. "requires_grad": "bool?",
  1261. }
# bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> DispatchLambdaArgumentExprs:
    """Bind PythonArgParser outputs to the C++ dispatch lambda's arguments.

    Returns a DispatchLambdaArgumentExprs:

    * ``exprs`` holds one C++ expression per dispatch-lambda argument, in
      lambda-argument order;
    * ``inits`` holds C++ statements to emit before the lambda is called
      (local unpacking, TensorOptions packing, out-type checks).

    Raises:
        RuntimeError: on inconsistent TensorOptions arguments (tensor options
            combined with an out arg, an unknown field or type, a missing
            field, or missing requires_grad/layout/device for an out variant).
        KeyError: from the final lookup, if some lambda argument ends up with
            no binding expression.
    """
    # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
    # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
    # outputs.
    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
    inits: list[str] = []
    lambda_args_exprs: dict[str, str] = {}

    has_toptions = has_tensor_options(f)

    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[a.name].expr

        if has_toptions and name == "self":
            # TODO: why this needs to be special case?
            inits.extend(
                [
                    f"auto self = {arg_parser_expr};",
                ]
            )
            lambda_args_exprs[name] = name
        elif (
            isinstance(a, PythonOutArgument)
            and len(a.outputs) > 1
            and f.func.is_out_fn()
        ):
            # Several out tensors: bind the parsed output once as a local
            # ``out`` and pass each element to the lambda by index.
            inits.extend(
                [
                    f"auto out = {arg_parser_expr};",
                ]
            )
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f"out[{i}]"
        elif str(a.type) == "Dimname[]?":
            # [old codegen]
            # TODO: make this part of something more general, or get rid of it.
            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
            # optional<vector<T>>, which cannot be implicitly converted to
            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
            inits.extend(
                [
                    f"auto __{name} = {arg_parser_expr};",
                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",  # noqa: B950
                ]
            )
            lambda_args_exprs[name] = name
        else:
            # default case - directly using PythonArgParser output expr
            lambda_args_exprs[name] = arg_parser_expr

    # method's self is passed directly to python binding, rather than parsed
    if ps.method:
        lambda_args_exprs["self"] = "self"

    # 2. special packing/checking for TensorOptions.
    # When f takes TensorOptions, every field in TENSOR_OPTIONS_FIELDS must be
    # present with its expected type; they are packed into one ``options`` local.
    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f"{f.func}: tensor options with output arg")
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
                raise RuntimeError(
                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                )
        if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
            raise RuntimeError(
                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
            )

        inits.append(
            f"""\
const auto options = TensorOptions()
.dtype({arg_parser_outputs["dtype"].expr})
.device({arg_parser_outputs["device"].expr})
.layout({arg_parser_outputs["layout"].expr})
.requires_grad({arg_parser_outputs["requires_grad"].expr})
.pinned_memory({arg_parser_outputs["pin_memory"].expr});
torch::utils::maybe_initialize_device(options);
"""
        )
        lambda_args_exprs["options"] = "options"

    # 3. special case - access scattered TensorOptions fields without packing
    # TODO: maybe move to the generator side as it's not related to binding.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # we're an output-arg variant, check these args against output tensor
            if not f.func.is_out_fn():
                raise RuntimeError(
                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments}"
                )
            if not all(a in tensor_options_args_names for a in ("layout", "device")):
                raise RuntimeError(
                    f"{f.func}: incomplete tensor options for output check"
                )

            inits.append(
                f"""\
check_out_type_matches({arg_parser_outputs["out"].expr}, {arg_parser_outputs["dtype"].expr},
{arg_parser_outputs["dtype"].is_none_expr}, {arg_parser_outputs["layout"].expr},
{arg_parser_outputs["device"].expr}, {arg_parser_outputs["device"].is_none_expr});
"""
            )
        # we'll set requires_grad on outgoing tensor
        if "requires_grad" not in tensor_options_args_names:
            raise RuntimeError(
                f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
            )

    # One expression per lambda argument, in the lambda's parameter order.
    return DispatchLambdaArgumentExprs(
        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
        inits=inits,
    )