| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550 |
- from __future__ import annotations
- from dataclasses import dataclass
- from typing import TYPE_CHECKING
- from torchgen.api import cpp
- from torchgen.api.types import Binding, CppSignature, CppSignatureGroup
- from torchgen.gen import pythonify_default
- from torchgen.model import (
- Argument,
- BaseTy,
- BaseType,
- FunctionSchema,
- ListType,
- NativeFunction,
- OptionalType,
- Return,
- Type,
- Variant,
- )
- if TYPE_CHECKING:
- from collections.abc import Iterable, Sequence
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Data Models
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # [Notes] python binding codegen
- #
- # The Python binding codegen produces code that takes the input list of
- # PyObjects, finds the matching ATen C++ function using PythonArgParser,
- # converts the PyObjects into C++ types and calls the ATen C++ function:
- #
- # +--------+ parsing +------------------------+ binding +-----------------------+
- # | PyObjs | ---------> | PythonArgParser Output | ---------> | Cpp Function Dispatch |
- # +--------+ +------------------------+ +-----------------------+
- #
- # The following examples demonstrate the data models the Python binding
- # codegen needs to deal with and the tasks it needs to accomplish. It
- # helps understand the purpose of the new data types we introduced below.
- #
- # - Function Schema (source of truth)
- #
- # aten::empty.names(int[] size, *, Dimname[]? names,
- # ScalarType? dtype=None, Layout? layout=None,
- # Device? device=None, bool? pin_memory=None,
- # MemoryFormat? memory_format=None) -> Tensor
- #
- # - Python Signature
- #
- # It's used to generate input schema string for PythonArgParser.
- # Note: TensorOptions fields are reordered and the additional
- # 'requires_grad' field is added:
- #
- # empty(IntArrayRef size, *, DimnameList? names,
- # MemoryFormat? memory_format=None, ScalarType dtype=None,
- # Layout layout=torch.strided, Device device=None,
- # bool pin_memory=False, bool requires_grad=False)
- #
- # - C++ Signature
- #
- # It's used to generate C++ lambda formals & dispatch call.
- # Note: the scattered TensorOptions fields are packed into 'options'.
- #
- # auto dispatch_empty =
- # [](IntArrayRef size, std::optional<DimnameList> names,
- # const TensorOptions & options,
- # std::optional<MemoryFormat> memory_format) -> Tensor {
- # pybind11::gil_scoped_release no_gil;
- # return torch::empty(size, names, options, memory_format);
- # };
- #
- # - Binding between Python Arguments and C++ Arguments
- #
- # Given a set of Python Arguments in scope, we need to produce the
- # binding expressions that translate the Python API into the C++ API:
- #
- # Python Args Cpp Args Binding Exprs
- # -----------------------------------------------------------------
- # 0: size size '_r.intlist(0)'
- # 1: names names 'names' [special init]
- # 2: memory_format -------+
- # 3: dtype -----+-|--> options 'options' [special packing]
- # 4: layout / |
- # 5: device / +--> memory_format '_r.memoryformatOptional(2)'
- # 6: pin_memory /
- # 7: requires_grad -+
- #
- # So the full dispatch expression would look like:
- #
- # dispatch_empty(_r.intlist(0), names, options,
- # _r.memoryformatOptional(2))
- #
- # Where does 'names' come from? It involves special local init:
- #
- # auto __names = _r.toDimnameListOptional(1);
- # std::optional<DimnameList> names =
- # __names ? std::make_optional(DimnameList(__names.value()))
- # : std::nullopt;
- #
- # Where does 'options' come from? It involves special local init
- # for TensorOptions. Note that Python side has the additional
- # 'requires_grad' field:
- #
- # const auto options = TensorOptions()
- # .dtype(_r.scalartype(3))
- # .device(_r.device(5))
- # .layout(_r.layoutOptional(4))
- # .requires_grad(_r.toBool(7))
- # .pinned_memory(_r.toBool(6));
- #
- # In some other cases one Python Argument can map to multiple C++
- # Arguments. For example:
- #
- # aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False)
- # -> (Tensor values, Tensor indices)
- #
- # Python Args Cpp Args Binding Exprs
- # ---------------------------------------------------------------------
- # +----> max 'out[0]'
- # /-----> max_values 'out[1]'
- # 0: input / self '_r.tensor(0)'
- # 1: dim / dim '_r.dimname(1)'
- # 2: keepdim / keepdim '_r.toBool(2)'
- # 3: out -----+ [local init] out '_r.tensorlist_n<2>(3)'
- #
- # As demonstrated above, the binding can involve reordering,
- # packing, unpacking and special local inits.
- #
- #
- # Let's look at a concrete example:
- #
- # static PythonArgParser parser({
- # "abs(Tensor input, *, Tensor out=None)",
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # ^
- # +--- Python Schema, represented by PythonSignature and PythonArgument
- #
- # }, /*traceable=*/true);
- #
- # ParsedArgs<2> parsed_args;
- # auto _r = parser.parse(nullptr, args, kwargs, parsed_args);
- #
- # ...
- #
- # if (_r.isNone(1)) {
- # ~~~~~~~~~~~~ <--- Scattered PythonArgParser output (arg name = 'out')
- # represented by PythonArgParserOutputExpr
- #
- # // aten::abs(Tensor self) -> Tensor
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # ^
- # +--- NativeFunction schema, base version
- #
- # auto dispatch_abs = [](const Tensor & self) -> Tensor {
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # ^
- # +--- dispatch_lambda_args / dispatch_lambda_return_str
- # generated from NativeFunction / CppSignature
- # (deprecated PythonSignature is special)
- # arguments are represented by DispatchLambdaArgument
- #
- # pybind11::gil_scoped_release no_gil;
- # return self.abs();
- # ~~~~~~~~~~~ <--- cpp_dispatch_target / cpp_dispatch_exprs
- # generated from NativeFunction / CppSignature
- # };
- # return wrap(dispatch_abs(_r.tensor(0)));
- # ~~~~~~~~~~~~~
- # ^
- # +--- dispatch_lambda_exprs
- # binding PythonArgParserOutputExpr (python args)
- # and DispatchLambdaArgument (c++ args)
- #
- # } else {
- # // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
- # ^
- # +--- NativeFunction schema, out-variant
- #
- # auto dispatch_abs_out = [](Tensor out, const Tensor & self) -> Tensor {
- # pybind11::gil_scoped_release no_gil;
- # return at::abs_out(out, self);
- # };
- # return wrap(dispatch_abs_out(_r.tensor(1), _r.tensor(0)));
- # }
- #
- #
- # [Notes] python interface codegen
- # The python dataclasses below are used to generate both python binding code
- # and pyi type hint signatures.
- # In theory these two should look very similar, but there are a number of differences
- # in how pyi signatures vs. python_arg_parser signatures are generated.
- # These differences have been encapsulated in signature_str() vs. signature_str_pyi()
- # to display the full signatures, and argument_str() vs. argument_str_pyi() to display arguments.
- # For example, only pyi signatures include return types.
def format_function_signature(
    name: str, arguments: Iterable[str] = (), return_type: str | None = None
) -> str:
    """Render a stub-style ``def name(...) -> ret: ...`` declaration.

    Args:
        name: Function name to emit after ``def``.
        arguments: Already-rendered argument strings. Any iterable is accepted;
            non-list/tuple iterables are materialized once into a tuple.
        return_type: Rendered return annotation, or ``None`` to omit ``-> ...``.

    Returns:
        A single-line declaration when it fits in 80 columns (or when there are
        no arguments, or the only argument is ``self``). Otherwise a multi-line
        declaration with one argument per line, indented by four spaces. If any
        of those lines is still longer than 80 columns, the header line gets a
        trailing ``# fmt: skip`` and the ``...`` body moves to its own line.
    """
    if not isinstance(arguments, (list, tuple)):
        arguments = tuple(arguments)
    return_type = f" -> {return_type}" if return_type is not None else ""

    sig = f"def {name}({', '.join(arguments)}){return_type}: ..."
    if len(sig) <= 80 or len(arguments) == 0 or tuple(arguments) == ("self",):
        return sig

    # One argument per line, using the 4-space indent a formatter would produce.
    lines = [
        f"def {name}(",
        *(f"    {arg}," for arg in arguments),
        f"){return_type}: ...",
    ]
    sig = "\n".join(lines)
    if all(len(line) <= 80 for line in lines):
        return sig

    # ruff format bug for compound statements: https://github.com/astral-sh/ruff/issues/18658
    # use `skip` instead of `on` + `off`
    # Inline comments are separated from code by two spaces and the body is
    # indented by four, so the emitted text is already formatter-stable.
    return sig.removesuffix(" ...") + "  # fmt: skip\n    ..."
@dataclass(frozen=True)
class PythonReturns:
    """Immutable wrapper around the return entries of a Python signature."""

    # The wrapped return entries, stored as a tuple.
    returns: tuple[Return, ...]
@dataclass(frozen=True)
class PythonArgument:
    """A single argument of a Python signature.

    Provides two renderings of the argument: ``argument_str`` for the
    PythonArgParser schema string and ``argument_str_pyi`` for ``.pyi`` stubs.
    """

    name: str
    type: Type
    default: str | None

    # Used to generate the default init expr for some PythonArgParser outputs, e.g.:
    #
    #   _r.layoutWithDefault(3, layout_from_backend(self.options().backend())))
    #                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    #                           ^
    #                           +--- default_init str
    default_init: str | None

    def argument_str(self, *, method: bool = False, symint: bool = True) -> str:
        """Render the argument formal for python argument parsing.

        Needs to be consistent with torch/csrc/utils/python_arg_parser.h.

        Args:
            method: Whether the enclosing signature is a method binding.
            symint: Forwarded to ``argument_type_str``.

        Returns:
            ``"<type> <name>"`` or ``"<type> <name>=<default>"``.
        """
        raw_type = argument_type_str(self.type, symint=symint)
        type_str = raw_type.replace("const ", "").replace(" &", "")

        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        rename_self = (
            self.name == "self" and type_str in ["Tensor", "Number"] and not method
        )
        name = "input" if rename_self else self.name

        if self.default is None:
            return f"{type_str} {name}"

        # C++ spellings of "no value" are written as None in the schema string.
        none_spellings = ("nullptr", "::std::nullopt", "std::nullopt", "{}")
        default = "None" if self.default in none_spellings else self.default
        return f"{type_str} {name}={default}"

    def argument_str_pyi(
        self, *, method: bool = False, deprecated: bool = False
    ) -> str:
        """Render the argument for a ``.pyi`` signature.

        Args:
            method: Whether the enclosing signature is a method binding.
            deprecated: Whether the enclosing signature is a deprecated one;
                this disables the ``self`` -> ``input`` rename, the optional
                ``out`` widening, and the ``None`` default of out arguments.

        Returns:
            ``"<name>: <type>"`` or ``"<name>: <type> = <default>"``.
        """
        type_str = argument_type_str_pyi(self.type)
        name = self.name

        # s/self/input/ outside method bindings
        # [old codegen] TODO: remove this? doesn't rename in codegen, it's just
        # for the parse string
        if name == "self" and type_str == "Tensor" and not method and not deprecated:
            name = "input"

        # 'from' is a Python keyword, so it cannot be used as a parameter name.
        if name == "from":
            name += "_"

        # pyi merges the _out and functional variants into the same signature,
        # with an optional out arg
        if name == "out" and type_str == "Tensor" and not deprecated:
            type_str = f"{type_str} | None".replace(" | None | None", " | None")

        # pyi deprecated signatures don't get defaults for their out arg
        treat_as_no_default = (
            deprecated
            and isinstance(self, PythonOutArgument)
            and self.default == "None"
        )
        if self.default is None or treat_as_no_default:
            return f"{name}: {type_str}"

        is_braced_int_list = (
            isinstance(self.type, ListType)
            and self.type.elem == BaseType(BaseTy.int)
            and self.default.startswith("{")
            and self.default.endswith("}")
        )
        if is_braced_int_list:
            # '{1, 2}' is rewritten as the Python tuple literal '(1, 2)'.
            items = (item.strip() for item in self.default[1:-1].split(","))
            default = "(" + ", ".join(items) + ")"
        else:
            translations = {
                "nullptr": "None",
                "::std::nullopt": "None",
                "std::nullopt": "None",
                "{}": "None",
                "c10::MemoryFormat::Contiguous": "contiguous_format",
                "QScheme::PER_TENSOR_AFFINE": "per_tensor_affine",
            }
            default = translations.get(self.default, self.default)
        return f"{name}: {type_str} = {default}"
@dataclass(frozen=True)
class PythonOutArgument(PythonArgument):
    """The single Python 'out' argument that stands for all output fields."""

    # In Python signature multiple output fields are packed into one 'out' argument.
    # When binding to C++, it's first bound to a local 'out' variable:
    #   'auto out = _r.tensorlist_n<2>(2);',
    # then bound to scattered C++ output arguments as 'out[0]', 'out[1]', and etc.
    # TODO: maybe don't need keep scattered out fields for python signature?
    outputs: tuple[PythonArgument, ...]

    @staticmethod
    def from_outputs(outputs: tuple[PythonArgument, ...]) -> PythonOutArgument | None:
        """Pack the scattered output arguments into one out argument.

        Args:
            outputs: The individual output arguments.

        Returns:
            ``None`` when ``outputs`` is empty. For a single output, an out
            argument reusing that output's name and type. For several outputs,
            an argument named ``out`` typed as a sized list of Tensor.

        Raises:
            RuntimeError: If there are several outputs and any of them is not
                tensor-like.
        """
        if not outputs:
            return None

        # Fields shared by both the single-output and the multi-output form.
        shared = {"default": "None", "default_init": None, "outputs": outputs}
        count = len(outputs)

        if count == 1:
            only = outputs[0]
            return PythonOutArgument(name=only.name, type=only.type, **shared)
        elif count > 1:
            if not all(a.type.is_tensor_like() for a in outputs):
                raise RuntimeError(f"Unsupported output type: {outputs}")
            return PythonOutArgument(
                name="out",
                # TODO: shouldn't this be OptionalType[ListType[...]], since it defaults to None?
                type=ListType(BaseType(BaseTy.Tensor), count),
                **shared,
            )
        raise AssertionError("Unexpected PythonOutArgument size")
@dataclass(frozen=True)
class PythonSignature:
    """Python-side signature of an operator.

    The same object renders both the PythonArgParser schema string
    (``signature_str``) and the ``.pyi`` stub signatures
    (``signature_str_pyi`` / ``signature_str_pyi_vararg``).
    """

    # Base operator name, without inplace/outplace suffix.
    name: str

    # Positional arguments.
    # TODO: create a dedicated SelfArgument type for 'self'?
    input_args: tuple[PythonArgument, ...]

    # Keyword arguments excluding the 'out' argument and scattered kwargs belonging
    # to TensorOptions (dtype, layout, device, pin_memory, requires_grad, etc).
    input_kwargs: tuple[PythonArgument, ...]

    output_args: PythonOutArgument | None

    # Return types, which are only used by pyi
    returns: PythonReturns

    # These are scattered kwargs arguments belonging to TensorOptions.
    # When binding to C++, they are packed into a TensorOptions object 'options'.
    # It's possible that the C++ signature doesn't take TensorOptions object (e.g.
    # for out variant), in which case they will be used as scattered fields without
    # being packed into 'options'.
    # TODO: maybe create a PythonTensorOptionsArgument?
    tensor_options_args: tuple[PythonArgument, ...]

    # method or function signature?
    method: bool

    @property
    def deprecated(self) -> bool:
        """Whether this is a deprecated signature; always ``False`` here."""
        return False

    def arguments(
        self, *, skip_outputs: bool = False, skip_tensor_options: bool = False
    ) -> tuple[PythonArgument | PythonOutArgument, ...]:
        """Return all arguments in order.

        The order is: positional inputs, keyword inputs, the out argument
        (unless absent or ``skip_outputs``), then the TensorOptions kwargs
        (unless ``skip_tensor_options``).
        """
        result: list[PythonArgument | PythonOutArgument] = [
            *self.input_args,
            *self.input_kwargs,
        ]
        if self.output_args is not None and not skip_outputs:
            result.append(self.output_args)
        if not skip_tensor_options:
            result.extend(self.tensor_options_args)
        return tuple(result)

    def arguments_count(self) -> int:
        """Return the number of arguments, outputs and tensor options included."""
        return len(self.arguments())

    def output_idx(self) -> int:
        """Return the index of the out argument within ``arguments()``."""
        return len(self.input_args) + len(self.input_kwargs)

    # [old codegen] Compute the Python function signature for argument parsing,
    # as specified in torch/csrc/utils/python_arg_parser.h.  WARNING:
    # this is NOT the same type signature as specified by PEP 484
    # as understood by mypy; our format was independently developed
    # and has some quirks to make it more suitable specifically
    # for error parsing.
    #
    # For a translation to mypy-valid type signatures, see
    # signature_str_pyi().
    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        """Render the PythonArgParser schema string ``name(arg, ..., *, kwarg, ...)``."""
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str(method=self.method, symint=symint) for a in args
        ]
        # Everything after the positional inputs is keyword-only.
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        return f"{self.name}({', '.join(schema_formals)})"

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        """Render the ``.pyi`` signature, including the return annotation."""
        args = self.arguments(skip_outputs=skip_outputs)
        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        positional_argc = len(self.input_args)
        if len(schema_formals) > positional_argc:
            schema_formals.insert(positional_argc, "*")

        # only pyi signatures include returns
        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return format_function_signature(self.name, schema_formals, returns_str)

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        """Render the ``*args`` flavour of the ``.pyi`` signature, if applicable.

        A vararg variant is only produced when the signature has exactly one
        positional input and the first argument is a list of ``int`` or
        ``SymInt``. Only pyi uses vararg signatures.

        Returns:
            The rendered signature, or ``None`` when no vararg variant applies
            (including when no arguments remain after ``skip_outputs``).
        """
        args = self.arguments(skip_outputs=skip_outputs)
        # Guard on the list that is actually indexed below: with
        # skip_outputs=True it can be empty even though arguments_count() > 0.
        if not args:
            return None

        vararg_type = args[0].type
        if not (
            isinstance(vararg_type, ListType)
            and str(vararg_type.elem) in ["int", "SymInt"]
            and len(self.input_args) == 1
        ):
            return None

        schema_formals: list[str] = [
            a.argument_str_pyi(method=self.method) for a in args
        ]
        # Below are the major changes in vararg vs. regular pyi signatures:
        # the first argument becomes '*name: elem_type', and vararg signatures
        # also omit the asterisk separator.
        schema_formals[0] = (
            "*" + args[0].name + ": " + argument_type_str_pyi(vararg_type.elem)
        )

        returns_str = returns_str_pyi(self)
        # pyi also includes self (with no typing/defaults) for methods
        if self.method:
            schema_formals.insert(0, "self")
        return format_function_signature(self.name, schema_formals, returns_str)
- # The deprecated python signature involves some special logic, so create a
- # dedicated data model to store these extra properties.
@dataclass(frozen=True)
class PythonSignatureDeprecated(PythonSignature):
    """A deprecated Python signature together with its extra binding data."""

    # Schema for the deprecated function
    deprecated_schema: FunctionSchema

    # The deprecated signature might miss some arguments that the corresponding
    # C++ signature expects. We need to store the constant default values to pass in.
    # For example:
    #   [deprecate signature]: addmm(Scalar beta, Tensor self, Tensor mat1, Tensor mat2)
    #   [func schema]: aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
    #   [func call]: self.addmm(mat1, mat2, beta, 1)
    # We store ['self', 'mat1', 'mat2', 'beta', '1'] in this case.
    deprecated_args_exprs: tuple[str, ...]

    @property
    def deprecated(self) -> bool:
        """Always ``True`` for this signature kind."""
        return True

    def signature_str(self, *, skip_outputs: bool = False, symint: bool = True) -> str:
        """Render the regular schema string with a ``|deprecated`` suffix."""
        regular = PythonSignature.signature_str(
            self, skip_outputs=skip_outputs, symint=symint
        )
        return f"{regular}|deprecated"

    def signature_str_pyi(self, *, skip_outputs: bool = False) -> str:
        """Render the ``.pyi`` signature using the deprecated argument rendering.

        Unlike the base implementation, no ``self`` formal is prepended.
        """
        formals: list[str] = []
        for a in self.arguments(skip_outputs=skip_outputs):
            formals.append(a.argument_str_pyi(method=self.method, deprecated=True))

        # Everything after the positional inputs is keyword-only.
        positional_argc = len(self.input_args)
        if positional_argc < len(formals):
            formals.insert(positional_argc, "*")

        return format_function_signature(self.name, formals, returns_str_pyi(self))

    def signature_str_pyi_vararg(self, *, skip_outputs: bool = False) -> str | None:
        """Return ``None``: the codegen doesn't include vararg variants for deprecated signatures."""
        return None
- # This struct is used to hold the PythonSignature and its corresponding
- # NativeFunction BEFORE grouping base and out-variant functions.
- # Why not store NativeFunction in PythonSignature or construct PythonSignature
- # from NativeFunction? Because they are not 1-1 mapped.
- # One native function could have both deprecated and non-deprecated python
- # signatures - NativeFunction doesn't contain information to construct the
- # deprecated python signature.
- # One python signature is used to handle both the base and the out-variant
- # function - see 'PythonSignatureGroup'.
@dataclass(frozen=True)
class PythonSignatureNativeFunctionPair:
    """Pairs one Python signature with the native function it was built for."""

    # The Python signature half of the pair.
    signature: PythonSignature
    # The native function half of the pair.
    function: NativeFunction
- # We merge pairs of functions with signatures that are equivalent mod
- # output arguments, and use a single entry in the python_arg_parser sig
- # list for both (output arguments become optional).
@dataclass(frozen=True)
class PythonSignatureGroup:
    """A base function and its optional out variant under one parse signature."""

    # The signature used for Python argument parsing. The outplace signature
    # is preferred if exists, because it can be used to parse inputs for both
    # the out-place variant and the base version (with output omitted).
    signature: PythonSignature

    # The regular ATen declaration (e.g. conv2d)
    base: NativeFunction

    # The out variant (e.g. conv2d_out)
    outplace: NativeFunction | None

    @classmethod
    def from_pairs(
        cls,
        functional: PythonSignatureNativeFunctionPair,
        out: PythonSignatureNativeFunctionPair | None,
    ) -> PythonSignatureGroup:
        """Build a group from a functional pair and an optional out pair.

        Args:
            functional: Pair for the base (functional) variant.
            out: Pair for the out variant, or ``None`` if there is none.

        Returns:
            A group whose signature is the functional one when ``out`` is
            ``None``; otherwise a copy of the out signature whose
            ``tensor_options_args`` are taken from the functional signature.
        """
        if out is not None:
            # prefer the signature with optional out=... arguments because it's the
            # superset that can be used to parse input for both base and outplace.
            # Out overloads in C++ don't have TensorOptions arguments,
            # so take these from the functional variant
            merged_fields = {
                **out.signature.__dict__,
                "tensor_options_args": functional.signature.tensor_options_args,
            }
            return PythonSignatureGroup(
                signature=type(out.signature)(**merged_fields),
                base=functional.function,
                outplace=out.function,
            )

        return PythonSignatureGroup(
            signature=functional.signature,
            base=functional.function,
            outplace=None,
        )
- # C++ function dispatch is wrapped in a lambda function. The lambda function
- # has almost the same signature as the C++ function, only with some small
- # variants - see details below.
- # This data model is used to represent arguments of the lambda function
- # signature.
@dataclass(frozen=True)
class DispatchLambdaArgument:
    """One argument of the lambda that wraps the C++ dispatch call."""

    # Argument name.
    name: str
    # Type of the argument, already rendered as a string.
    type_str: str
    # Flag marking the argument as an out argument.
    is_out_arg: bool
- # To pass PyObject arguments to the C++ function (via the lambda wrapper),
- # we first need to convert the PyObjects into simple C++ objects. This work
- # is done by PythonArgParser.
- # This data model is used to represent the output of PythonArgParser.
- # It has a 1-1 mapping with PythonArgument in PythonSignature.
@dataclass(frozen=True)
class PythonArgParserOutputExpr:
    """Expression that reads one parsed argument out of PythonArgParser."""

    # argument name
    name: str

    # RHS expression to reference PythonArgParser output.
    expr: str

    # In some special cases we need create different expr, e.g.:
    # '_r.isNone(1)' instead of '_r.tensor(1)'.
    index: int

    # The python argument it maps to.
    argument: PythonArgument

    @property
    def is_none_expr(self) -> str:
        """Return the ``_r.isNone(<index>)`` expression for this argument."""
        return "_r.isNone({})".format(self.index)
- # To pass PythonArgParser output to the lambda wrapper, we need to bind
- # PythonArgParserOutputExpr to DispatchLambdaArgument.
- # They are not always 1-1 mapped, e.g. scattered TensorOptions fields
- # need to be packed into a TensorOptions object, which is the argument
- # that the lambda function wrapper takes.
@dataclass(frozen=True)
class DispatchLambdaArgumentExprs:
    """Binding expressions for the lambda arguments plus their local inits."""

    # The exprs that provide the binding for lambda arguments, e.g.:
    #
    #   'self' -> '_r.tensor(0)'
    #   'min' -> 'out[0]' / 'min_indices' -> 'out[1]'
    #   'options' -> 'options'
    #
    # It has 1-1 mapping with DispatchLambdaArgument.
    exprs: Sequence[str]

    # Special local inits, which might introduce new variables that
    # the 'exprs' above reference, e.g.:
    #
    #   'auto out = _r.tensorlist_n<2>(2);'
    #
    inits: Sequence[str]
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Helper Functions
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
    """Return the ``signature`` of the C++ signature group built for ``f``.

    Args:
        f: Native function to build the signature group from.
        method: Forwarded to ``CppSignatureGroup.from_native_function``.
    """
    group = CppSignatureGroup.from_native_function(f, method=method)
    return group.signature
def has_tensor_options(f: NativeFunction) -> bool:
    """Return whether the schema of ``f`` has a ``tensor_options`` argument group."""
    tensor_options = f.func.arguments.tensor_options
    return tensor_options is not None
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Python Signature
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- # 'simple_type' was introduced by the old codegen, which is slightly
- # different from the python schema type, e.g.: doesn't have '?' suffix
- # for optional Tensor/TensorList; doesn't have '[size]' suffix for list type.
def argument_type_str(
    t: Type, *, simple_type: bool = False, symint: bool = True
) -> str:
    """Render a schema type as the type name used in the python schema string.

    Args:
        t: The schema type to render.
        simple_type: When true, list types are rendered without a ``[size]``
            suffix and ``Tensor?[]`` is rendered without ``const``/``&``.
        symint: When false, ``SymInt[]`` is rendered as ``IntArrayRef``
            instead of ``SymIntArrayRef``.

    Returns:
        The rendered type string.

    Raises:
        AssertionError: For a ``bool`` list without a size.
        RuntimeError: For any type that is not recognized.
    """
    if isinstance(t, BaseType):
        if t.name == BaseTy.int:
            return "int64_t"
        if t.name == BaseTy.float:
            return "double"
        if t.name == BaseTy.str:
            return "c10::string_view"
        if t.name in [
            BaseTy.Tensor,
            BaseTy.bool,
            BaseTy.QScheme,
            BaseTy.Scalar,
            BaseTy.ScalarType,
            BaseTy.Generator,
            BaseTy.Storage,
            BaseTy.Layout,
            BaseTy.Device,
            BaseTy.DeviceIndex,
            BaseTy.MemoryFormat,
            BaseTy.Dimname,
            BaseTy.Stream,
            BaseTy.SymInt,
        ]:
            # These python schema type names line up with their function schema names
            return t.name.name
        # Any other base type falls through to the error at the bottom.

    elif isinstance(t, OptionalType):
        inner = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"{inner}?"

    elif isinstance(t, ListType):
        size = None if simple_type else t.size
        elem_name = str(t.elem)

        def sized(base: str) -> str:
            # Append '[size]' only when a size is known and wanted.
            return base if size is None else f"{base}[{size}]"

        if elem_name == "bool":
            if t.size is None:
                raise AssertionError("bool ListType must have a size")
            return f"::std::array<bool,{t.size}>"
        if elem_name == "int":
            return sized("IntArrayRef")
        if elem_name == "SymInt":
            return sized("SymIntArrayRef" if symint else "IntArrayRef")
        if elem_name == "Tensor":
            return sized("TensorList")
        if elem_name == "Scalar":
            return sized("ScalarList")
        if elem_name == "Tensor?":
            if simple_type:
                return "c10::List<::std::optional<Tensor>>"
            return "const c10::List<::std::optional<Tensor>> &"
        if elem_name == "Dimname":
            return sized("DimnameList")

        inner = argument_type_str(t.elem, simple_type=simple_type, symint=symint)
        return f"ArrayRef<{inner}>"

    raise RuntimeError(f"unrecognized type {repr(t)}")
def argument_type_size(t: Type) -> int | None:
    """Return the list size of ``t``, or ``None`` when there is none to report.

    ``None`` is returned when ``t.is_list_like()`` yields ``None`` and also
    when the list element type is ``bool``.
    """
    list_type = t.is_list_like()
    if list_type is None or str(list_type.elem) == "bool":
        return None
    return list_type.size
def argument(a: Argument) -> PythonArgument:
    """Convert a schema argument into a ``PythonArgument``.

    The name and type are copied unchanged. A schema default, when present, is
    rendered through ``cpp.default_expr`` and ``pythonify_default`` and stored
    as a string; ``default_init`` is always ``None``.
    """
    # TODO: directly translate a.default to python default
    default: str | None = None
    if a.default is not None:
        cpp_default = cpp.default_expr(a.default, a.type, symint=False)
        default = str(pythonify_default(cpp_default))

    return PythonArgument(
        name=a.name,
        type=a.type,
        default=default,
        default_init=None,
    )
- # Generates a PythonSignature that can be used for either .pyi or PythonArgParser codegen
def signature(
    f: NativeFunction, *, method: bool = False, pyi: bool = False
) -> PythonSignature:
    """Build the ``PythonSignature`` for a native function.

    Delegates to ``signature_from_schema`` with the function's schema and its
    ``category_override``.

    Args:
        f: The native function to describe.
        method: Forwarded to ``signature_from_schema``.
        pyi: Forwarded to ``signature_from_schema``.
    """
    schema = f.func
    override = f.category_override
    return signature_from_schema(
        schema, category_override=override, method=method, pyi=pyi
    )
def signature_from_schema(
    func: FunctionSchema,
    *,
    category_override: str | None,
    method: bool = False,
    pyi: bool = False,
) -> PythonSignature:
    """Build the ``PythonSignature`` that corresponds to a function schema.

    Args:
        func: The schema to translate.
        category_override: Optional category tag.  ``"factory"``, ``"new"``
            and ``"like"`` request the scattered tensor-options arguments;
            ``"dummy"`` suppresses them.
        method: When true, ``self`` is left out of the python arguments.
        pyi: Accepted for interface compatibility; not read in this function.

    Returns:
        A ``PythonSignature`` holding positional inputs, keyword-only inputs,
        the packed output argument, the tensor-options arguments (possibly
        empty) and the returns.

    Raises:
        ValueError: If the schema declares an argument named
            ``requires_grad``, which is reserved for the python side.
    """
    schema_args = func.arguments

    # Gather, in order, the schema arguments that surface on the python side.
    args: list[Argument] = list(schema_args.pre_self_positional)
    # Skip SelfArgument if this is method.
    if not method and schema_args.self_arg is not None:
        args.append(schema_args.self_arg.argument)
    args += schema_args.post_self_positional
    args += schema_args.pre_tensor_options_kwarg_only
    # Skip TensorOptionsArguments. Python side TensorOptions
    # arguments are created based on different rules - see below.
    args += schema_args.post_tensor_options_kwarg_only
    args += schema_args.out

    # Classify the gathered arguments by name.
    positional_names = {a.name for a in schema_args.flat_positional}
    kwarg_only_names = {a.name for a in schema_args.flat_kwarg_only}
    out_names = {a.name for a in schema_args.out}

    input_args = tuple(argument(a) for a in args if a.name in positional_names)
    input_kwargs = tuple(argument(a) for a in args if a.name in kwarg_only_names)
    outputs = tuple(argument(a) for a in args if a.name in out_names)

    # Reintroduce the scattered fields of TensorOptions for Python.
    # Compared to the cpp counterpart, the python arguments have new property
    # (default_init) and a new argument 'requires_grad', which require some
    # special handlings.
    # [old codegen] TODO: because these aren't guaranteed to be 100% faithful
    # to the original versions in the yaml, this recreation is a potential
    # source of drift between eager and JIT. Pull this logic out to a shared place.
    has_tensor_input_arg = any(
        a.type.is_tensor_like() for a in schema_args.flat_non_out
    )
    if any(a.name == "requires_grad" for a in func.schema_order_arguments()):
        raise ValueError(
            "argument named requires_grad is reserved, should not explicitly add it in the schema"
        )

    # [old codegen] this probably won't work if one of the returns is not a tensor,
    # but it will produce a compile-time error that is obvious.
    has_tensor_return = any(r.type.is_tensor_like() for r in func.returns)

    cpp_name: str = cpp.name(func)
    is_factory_function = category_override == "factory" or (
        has_tensor_return and not has_tensor_input_arg
    )
    is_like_or_new_function = (
        category_override in ("new", "like")
        or cpp_name.startswith("new_")
        or cpp_name.endswith("_like")
    )
    is_dummy_function = category_override == "dummy"
    wants_tensor_options = (
        is_factory_function or is_like_or_new_function
    ) and not is_dummy_function

    tensor_options_args: list[PythonArgument] = []
    if wants_tensor_options:

        def topt_default_init(field: str) -> str | None:
            # C++ default expression for one tensor-options field, or None
            # when the schema has no tensor options or no usable default.
            topt_args = schema_args.tensor_options
            if topt_args is None:
                return None
            topt = getattr(topt_args, field)
            if topt.default is None or topt.default == "None":
                return None
            return cpp.default_expr(topt.default, topt.type, symint=False)

        # like/new functions never take a default_init for these fields.
        if is_like_or_new_function:
            dtype_init = None
            layout_init = None
            device_init = None
        else:
            dtype_init = topt_default_init("dtype")
            layout_init = topt_default_init("layout")
            device_init = (
                topt_default_init("device") or "torch::tensors::get_default_device()"
            )

        tensor_options_args = [
            PythonArgument(
                name="dtype",
                type=OptionalType(BaseType(BaseTy.ScalarType)),
                default="None",
                default_init=dtype_init,
            ),
            PythonArgument(
                name="layout",
                type=OptionalType(BaseType(BaseTy.Layout)),
                default="None",
                default_init=layout_init,
            ),
            PythonArgument(
                name="device",
                type=OptionalType(BaseType(BaseTy.Device)),
                default="None",
                default_init=device_init,
            ),
            PythonArgument(
                name="pin_memory",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            ),
            PythonArgument(
                name="requires_grad",
                type=OptionalType(BaseType(BaseTy.bool)),
                default="False",
                default_init=None,
            ),
        ]

    return PythonSignature(
        name=str(func.name.name),
        input_args=input_args,
        input_kwargs=input_kwargs,
        output_args=PythonOutArgument.from_outputs(outputs),
        tensor_options_args=tuple(tensor_options_args),
        returns=PythonReturns(returns=func.returns),
        method=method,
    )
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Python Interface
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def structseq_fieldnames(returns: tuple[Return, ...]) -> list[str]:
    """Return the structseq field names for a tuple of returns.

    An empty list is returned when there is at most one return or when none
    of the returns is named.  Otherwise every return must be named and the
    names are returned as strings, in order.

    Raises:
        ValueError: If some, but not all, of the returns are named.
    """
    if len(returns) <= 1:
        return []

    names = [r.name for r in returns]
    if all(n is None for n in names):
        return []

    if any(n is None for n in names):
        # When building on Windows, `PyStructSequence_UnnamedField` could not be
        # resolved by the linker for some reason, which causes a build error:
        #
        # python_nn_functions.cpp.obj : error LNK2001: unresolved external symbol
        # PyStructSequence_UnnamedField
        #
        # Thus, at this point in time, we do not support unnamed
        # fields in structseq; you must either name all fields,
        # or none of them.
        raise ValueError("Unnamed field is not supported by codegen")

    return [str(n) for n in names]
def argument_type_str_pyi(t: Type) -> str:
    """Render a schema argument type as a ``.pyi`` annotation string.

    An ``OptionalType`` is unwrapped and `` | None`` is appended to the
    rendering of its element (without duplicating a trailing `` | None``).
    Base types are mapped through a fixed table; a base type missing from the
    table renders as an empty string.  List types render as ``_size``,
    tensor sequences or ``Sequence[...]`` depending on the element type.

    Raises:
        RuntimeError: If ``t`` (after unwrapping) is neither a ``BaseType``
            nor a ``ListType``.
    """
    add_optional = isinstance(t, OptionalType)
    if add_optional:
        t = t.elem

    if isinstance(t, BaseType):
        if t.name in (BaseTy.Tensor, BaseTy.Generator, BaseTy.Stream):
            # These python schema type names line up with their function schema names
            ret = t.name.name
        else:
            base_annotations = {
                BaseTy.int: "_int",
                BaseTy.DeviceIndex: "_int",
                BaseTy.SymInt: "_int | SymInt",
                BaseTy.float: "_float",
                BaseTy.str: "str",
                BaseTy.Scalar: "Number | _complex",
                BaseTy.ScalarType: "_dtype",
                BaseTy.bool: "_bool",
                BaseTy.QScheme: "_qscheme",
                BaseTy.Layout: "_layout",
                BaseTy.Device: "DeviceLikeType | None",
                BaseTy.MemoryFormat: "memory_format",
                BaseTy.Dimname: "str | EllipsisType | None",
                BaseTy.Storage: "Storage | UntypedStorage",
            }
            # Base types without an entry deliberately render as "".
            ret = base_annotations.get(t.name, "")
    elif isinstance(t, ListType):
        elem_name = str(t.elem)
        has_fixed_size = t.size is not None
        if elem_name == "int":
            ret = "_int | _size" if has_fixed_size else "_size"
        elif t.is_tensor_like():
            # TODO: this doesn't seem right...
            # Tensor?[] currently translates to tuple[Tensor, ...] | list[Tensor] | None
            # It should probably translate to tuple[Tensor | None, ...] | list[Tensor | None]
            add_optional = True
            if has_fixed_size:
                ret = "Tensor | tuple[Tensor, ...] | list[Tensor]"
            else:
                ret = "tuple[Tensor, ...] | list[Tensor]"
        elif elem_name == "float":
            ret = "Sequence[_float]"
        else:
            elem = argument_type_str_pyi(t.elem)
            if elem_name == "SymInt" and has_fixed_size:
                # A sized SymInt list also accepts a single element.
                ret = f"{elem} | Sequence[{elem}]"
            else:
                ret = f"Sequence[{elem}]"
    else:
        raise RuntimeError(f"unrecognized type {repr(t)}")

    if add_optional:
        # Avoid "X | None | None" when the inner rendering already ends in None.
        ret = f"{ret} | None".replace(" | None | None", " | None")
    return ret
def return_type_str_pyi(t: Type) -> str:
    """Render a schema return type as a ``.pyi`` annotation string.

    Where arguments are open to accepting unions, return types should be
    concrete: ``Device`` becomes ``_device``, ``Dimname`` becomes
    ``str | None`` and lists become variadic tuples.  Everything else is
    rendered exactly like an argument type.
    """
    if isinstance(t, OptionalType):
        wrapped = f"{return_type_str_pyi(t.elem)} | None"
        # Collapse the duplicate produced when the inner type already ends in None.
        return wrapped.replace(" | None | None", " | None")

    if isinstance(t, BaseType):
        concrete_overrides = {
            BaseTy.Device: "_device",
            BaseTy.Dimname: "str | None",
        }
        if t.name in concrete_overrides:
            return concrete_overrides[t.name]
    elif isinstance(t, ListType):
        return f"tuple[{return_type_str_pyi(t.elem)}, ...]"

    return argument_type_str_pyi(t)
def returns_structseq_pyi(signature: PythonSignature) -> tuple[str, str] | None:
    """Generate the ``.pyi`` class definition for a structseq return type.

    Returns ``None`` when the signature's returns have no structseq field
    names.  Otherwise returns ``(name, definition)`` where ``name`` is the
    signature name and ``definition`` is the source text of a ``tuple``
    subclass with one read-only property per field, a ``__new__`` taking the
    underlying tuple, the ``n_fields`` / ``n_sequence_fields`` /
    ``n_unnamed_fields`` constants and a ``__init_subclass__`` stub that
    prohibits subclassing.  The definition ends with a trailing newline.
    """
    returns = signature.returns.returns
    field_types = [return_type_str_pyi(r.type) for r in returns]
    field_names = structseq_fieldnames(returns)
    if not field_names:
        return None

    structseq_name = signature.name
    # These types are structseq objects which act like named NamedTuples, but
    # the constructor acts like the constructor of tuple. Using typing.NamedTuple
    # does not allow us to override __init__.
    seq_type = f"tuple[{', '.join(field_types)}]"
    field_count = len(field_names)

    lines = [f"class {structseq_name}({seq_type}): # fmt: skip"]
    for field, field_type in zip(field_names, field_types):
        lines.append(" @property")
        lines.append(f" def {field}(self) -> {field_type}: ...")
    lines += [
        " def __new__(",
        " cls,",
        f" sequence: {seq_type},",
        " ) -> Self: # fmt: skip",
        " ...",
        f" n_fields: Final[_int] = {field_count}",
        f" n_sequence_fields: Final[_int] = {field_count}",
        " n_unnamed_fields: Final[_int] = 0",
        " def __init_subclass__(cls) -> NoReturn: ... # prohibit subclassing",
        "",  # add an extra newline
    ]
    return structseq_name, "\n".join(lines)
def returns_str_pyi(signature: PythonSignature) -> str:
    """Render the return annotation of a signature for ``.pyi`` output.

    Structseq returns refer to ``torch.return_types.<name>``.  Otherwise
    several returns become a ``tuple[...]``, a single return is rendered on
    its own, and no returns render as ``None``.
    """
    returns = signature.returns.returns
    if structseq_fieldnames(returns):
        return f"torch.return_types.{signature.name}"

    rendered = [return_type_str_pyi(r.type) for r in returns]
    if not rendered:
        return "None"
    if len(rendered) == 1:
        return rendered[0]
    return "tuple[" + ", ".join(rendered) + "]"
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # C++ Function Dispatch
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- # This section provides APIs to generate the code that does C++ function
- # dispatch. The C++ function call is wrapped by a lambda function.
- # For example:
- #
- # // aten::selu_(Tensor(a!) self) -> Tensor(a!)
- # auto dispatch_selu_ = [](Tensor self) -> Tensor {
- # pybind11::gil_scoped_release no_gil;
- # return at::selu_(self);
- # };
- #
- # The lambda function's signature follows the C++ signature in common
- # cases, e.g.:
- #
- # // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
- # [](const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
- #
# For the out variant, the 'out' argument's type is changed from 'Tensor &'
# to 'Tensor'. This is because the lambda is called with the PythonArgParser
# output '_r.tensor(3)', which is a stack-allocated temporary and must be
# passed by value. Also see comments in 'dispatch_lambda_return_str()'.
- #
- # // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
- # [](Tensor out, const Tensor & self, const Tensor & other, Scalar alpha) -> Tensor
- #
- # For multi-output case it can keep using reference type because the
- # PythonArgParser output has been unpacked to local variables, e.g.:
- #
- # // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *,
- # // Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)
- # [](Tensor & max, Tensor & max_values, const Tensor & self, Dimname dim, bool keepdim) -> std::tuple<Tensor,Tensor>
- #
- # For deprecated python signature, it should follow deprecated python arg order.
- # TODO: This is to keep same byte-for-byte result as the old codegen - maybe unnecessary?
def dispatch_lambda_args(
    ps: PythonSignature, f: NativeFunction, symint: bool = True
) -> tuple[DispatchLambdaArgument, ...]:
    """Compute the parameter list of the C++ dispatch lambda.

    The parameters are derived from the C++ arguments of the schema (the
    deprecated schema when ``ps`` is a ``PythonSignatureDeprecated``,
    otherwise ``f.func``), with the argument types adjusted so that the
    lambda never binds a mutable reference to a temporary.
    """
    if isinstance(ps, PythonSignatureDeprecated):
        schema = ps.deprecated_schema
    else:
        schema = f.func

    # Start with cpp arguments - dispatch lambda signature always include 'self'
    bindings = cpp.arguments(
        arguments=schema.arguments,
        faithful=False,
        symint=symint,
        method=False,
        cpp_no_default_args=f.cpp_no_default_args,
    )
    out_names: set[str] = {a.name for a in schema.arguments.out}
    # With more than one output, the outputs are unpacked into locals first,
    # so those arguments may keep their reference types.
    has_scattered_outputs = len(out_names) > 1

    # Convert from cpp argument to lambda argument
    lambda_args = []
    for binding in bindings:
        is_out_arg = binding.name in out_names
        type_str = binding.type
        if ps.method and binding.name == "self":
            # For method's 'self', we can use 'const Tensor &' and simply ignore mutability!
            type_str = "const at::Tensor &"
        elif not (has_scattered_outputs and is_out_arg):
            # For other cases we need prevent dangling refs to temps (unless it's
            # unpacked scattered output)
            # The reason is explained in the comments above and in 'dispatch_lambda_return_str()'.
            # TODO: avoid this special handling?
            if type_str == "at::Tensor &":
                type_str = "at::Tensor"
        lambda_args.append(
            DispatchLambdaArgument(
                name=binding.name,
                type_str=type_str,
                is_out_arg=is_out_arg,
            )
        )
    return tuple(lambda_args)
- # [old codegen] XXX: if you got here because of an assertion failure, it doesn't mean
- # it's enough to just extend the list here. Before you do this, make sure
- # to add an appropriate wrap() overload in torch/csrc/autograd/utils/wrap_outputs.h.
# C++ return type strings accepted by dispatch_lambda_return_str(); any other
# return type makes that function raise.
SUPPORTED_RETURN_TYPES: set[str] = {
    # No value / scalars and other simple values
    "void",
    "void*",
    "bool",
    "int64_t",
    "double",
    "at::Scalar",
    "at::ScalarType",
    "at::QScheme",
    "at::Stream",
    "at::IntArrayRef",
    # A tensor or a list of tensors
    "at::Tensor",
    "::std::vector<at::Tensor>",
    # Tuples made only of tensors
    "::std::tuple<at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor,at::Tensor>",
    # Mixed tuples
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,int64_t>",
    "::std::tuple<at::Tensor,at::Tensor,double,at::Tensor,int64_t>",
    "::std::tuple<double,int64_t>",
    "::std::tuple<at::Tensor,::std::vector<at::Tensor>>",
    # Needed for flash attention forw/backward
    "::std::tuple<at::Tensor,at::Tensor,at::Tensor,at::Tensor,c10::SymInt,c10::SymInt,at::Tensor,at::Tensor,at::Tensor>",
}
def dispatch_lambda_return_str(f: NativeFunction) -> str:
    """Return the C++ return type of the dispatch lambda for ``f``.

    Raises:
        RuntimeError: If the resulting type is not listed in
            ``SUPPORTED_RETURN_TYPES``.
    """
    # [old codegen] Drop the annotation from every return (e.g. 'Tensor' rather
    # than 'Tensor &'): the dispatch lambdas take mutable arguments *by value*,
    # not by reference, so returning a reference to such an argument would
    # leave a pointer to a dangling stack entry.
    #
    # You want:
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor { ...; return at::selu_(self); };
    #
    # *not*
    #
    #   auto dispatch_selu_ = [](Tensor self) -> Tensor& { ...; return at::selu_(self); };
    #
    # (NB: We can't make dispatch_selu_ take Tensor&, because the enclosing
    # codegen looks like dispatch_selu_(_r.tensor(0)), and you can't take a
    # mutable reference to temporary. Maybe we could assign it to a
    # variable itself.)
    plain_returns = tuple(
        Return(ret.name, ret.type, None) for ret in f.func.returns
    )
    cpp_return_type = cpp.returns_type(plain_returns, symint=True).cpp_type()
    if cpp_return_type in SUPPORTED_RETURN_TYPES:
        return cpp_return_type
    raise RuntimeError(f"{f.func.name} returns unsupported type {cpp_return_type}")
def cpp_dispatch_target(f: NativeFunction) -> str:
    """Return the C++ callee expression the dispatch lambda should invoke.

    Method variants dispatch through ``self.<name>``.  Function variants
    dispatch to ``torch::<name>`` when the function has tensor options or its
    base name ends in ``_like``, and to ``at::<name>`` otherwise.

    Raises:
        RuntimeError: If ``f`` is neither a method nor a function variant.
    """
    target = cpp.name(f.func, symint_overload=f.func.has_symint())

    if Variant.method in f.variants:
        return f"self.{target}"

    if Variant.function not in f.variants:
        raise RuntimeError(
            f"could not dispatch, neither function nor method: {f.func}"
        )

    uses_torch_namespace = has_tensor_options(f) or f.func.name.name.base.endswith(
        "_like"
    )
    namespace = "torch" if uses_torch_namespace else "at"
    return f"{namespace}::{target}"
def cpp_dispatch_exprs(
    f: NativeFunction,
    *,
    python_signature: PythonSignature | None = None,
) -> tuple[str, ...]:
    """Return the argument expressions passed to the C++ dispatch target.

    By default these are the names of the C++ signature's arguments.  For a
    ``PythonSignatureDeprecated`` they come from its ``deprecated_args_exprs``
    instead, with ``"out"`` dropped unless ``f`` is an out function.  For
    method variants ``"self"`` is removed from the result.
    """
    cpp_args = _cpp_signature(f, method=False).arguments()

    if isinstance(python_signature, PythonSignatureDeprecated):
        # For deprecated python signature we may need fill in some constants.
        selected = [
            expr
            for expr in python_signature.deprecated_args_exprs
            if expr != "out" or f.func.is_out_fn()
        ]
    else:
        # By default the exprs are consistent with the C++ signature.
        selected = [binding.name for binding in cpp_args]

    if Variant.method in f.variants:
        selected = [expr for expr in selected if expr != "self"]

    return tuple(selected)
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # Python / C++ Args Binding
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- # We explicitly enumerate the PythonArgParser unpacking methods for all
- # supported types. This might be more verbose than necessary, partially
- # because of the irregularity of unpacking method naming, partially
- # because we want to mimic the old codegen behavior - to reject
- # unexpected and/or unsupported cases which the old codegen rejects.
- # For certain cases it is intentionally more restrictive than necessary,
# e.g.: it doesn't accept doublelist with definite size.
def arg_parser_unpack_method(
    t: Type, default: str | None, default_init: str | None, *, symint: bool = True
) -> str:
    """Return the name of the PythonArgParser method that unpacks type ``t``.

    Args:
        t: Schema type of the argument.
        default: Python-side default of the argument, if any.
        default_init: C++ default initializer, if any.  Only a fixed set of
            types supports it; for those a ``...WithDefault`` method is chosen.
        symint: Whether ``SymInt`` types use the symint unpacking methods
            (``toSymInt`` / ``symintlist``) rather than the int64 ones.

    Raises:
        RuntimeError: If ``default_init`` is given for a type that does not
            support it, or if ``t`` has no unpacking method.
    """
    has_default_init = default_init is not None
    types_with_default_init = {
        "ScalarType?",
        "ScalarType",
        "Device",
        "Device?",
        "Layout",
        "Layout?",
        "bool",
        "bool?",
    }
    if has_default_init and str(t) not in types_with_default_init:
        raise RuntimeError(f"type '{t}' does not supported unpacking with default")

    if isinstance(t, BaseType):
        base = t.name
        if base in (
            BaseTy.Tensor,
            BaseTy.Stream,
            BaseTy.Storage,
            BaseTy.Scalar,
            BaseTy.Dimname,
        ):
            # These unpack methods line up with their schema names
            return base.name.lower()

        # Types with a dedicated "...WithDefault" variant: (plain, with default).
        defaultable = {
            BaseTy.ScalarType: ("scalartype", "scalartypeWithDefault"),
            BaseTy.Device: ("device", "deviceWithDefault"),
            BaseTy.bool: ("toBool", "toBoolWithDefault"),
            BaseTy.Layout: ("layout", "layoutWithDefault"),
        }
        if base in defaultable:
            plain, with_default = defaultable[base]
            return with_default if has_default_init else plain

        fixed = {
            BaseTy.DeviceIndex: "toInt64",
            BaseTy.int: "toInt64",
            BaseTy.SymInt: "toSymInt" if symint else "toInt64",
            BaseTy.float: "toDouble",
            BaseTy.str: "stringView",
            BaseTy.MemoryFormat: "memoryformat",
        }
        if base in fixed:
            return fixed[base]

    elif isinstance(t, OptionalType):
        special_optionals = {
            "Tensor": "optionalTensor",
            "Generator": "generator",
            "Dimname[]": "toDimnameListOptional",
        }
        elem_name = str(t.elem)
        if elem_name in special_optionals:
            return special_optionals[elem_name]

        defaults_to_none = default in (
            None,
            "None",
            "::std::nullopt",
            "std::nullopt",
        )
        if not has_default_init and defaults_to_none:
            # If default is None: append 'Optional' to elem's unpacking method
            return (
                arg_parser_unpack_method(t.elem, None, None, symint=symint) + "Optional"
            )
        # Otherwise, load as underlying type with default
        return arg_parser_unpack_method(t.elem, default, default_init, symint=symint)

    elif isinstance(t, ListType):
        elem_name = str(t.elem)
        if elem_name == "Tensor":
            # accept and use definite size
            return f"tensorlist_n<{t.size}>" if t.size is not None else "tensorlist"

        # Dimname, int and SymInt lists accept a definite size.
        list_methods = {
            "Tensor?": "list_of_optional_tensors",
            "Dimname": "dimnamelist",
            "int": "intlist",
            "float": "doublelist",
            "SymInt": "symintlist" if symint else "intlist",
            "Scalar": "scalarlist",
        }
        if elem_name in list_methods:
            return list_methods[elem_name]

    raise RuntimeError(f"type '{t}' is not supported by PythonArgParser")
- # Return RHS expression for python argument using PythonArgParser output.
- # e.g. for arg name 'foo', arg type 'bool', arg_index = 2, returns '_r.toBool(2)'
def arg_parser_output_expr(
    arg_index: int, a: PythonArgument, *, symint: bool = True
) -> PythonArgParserOutputExpr:
    """Build the PythonArgParser read expression for one python argument.

    The expression has the form ``_r.<method>(<index>)``, or
    ``_r.<method>(<index>, <default_init>)`` when the argument carries a
    ``default_init``.
    """
    method_name = arg_parser_unpack_method(
        t=a.type, default=a.default, default_init=a.default_init, symint=symint
    )

    call_args = [str(arg_index)]
    if a.default_init is not None:
        call_args.append(str(a.default_init))

    return PythonArgParserOutputExpr(
        name=a.name,
        expr=f"_r.{method_name}({', '.join(call_args)})",
        index=arg_index,
        argument=a,
    )
- # Returns a map with key = arg_name and value = PythonArgParserOutputExpr.
def arg_parser_output_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> dict[str, PythonArgParserOutputExpr]:
    """Map each python argument name of ``ps`` to its arg-parser output expr.

    Arguments are indexed by their position in ``ps.arguments()``.  ``f`` is
    accepted for interface compatibility and is not read here.
    """
    outputs: dict[str, PythonArgParserOutputExpr] = {}
    for index, py_arg in enumerate(ps.arguments()):
        parsed = arg_parser_output_expr(index, py_arg, symint=symint)
        outputs[parsed.name] = parsed
    return outputs
- # argument name to type for scattered tensor options fields
# Expected schema type string for each scattered tensor-options argument;
# dispatch_lambda_exprs() validates the python binding arguments against it.
TENSOR_OPTIONS_FIELDS: dict[str, str] = {
    # fields that take an optional enum-like value
    "dtype": "ScalarType?",
    "device": "Device?",
    "layout": "Layout?",
    # optional boolean flags
    "pin_memory": "bool?",
    "requires_grad": "bool?",
}
- # bind arg parser outputs (python args) with dispatch lambda arguments (c++ args).
def dispatch_lambda_exprs(
    ps: PythonSignature, f: NativeFunction, *, symint: bool = True
) -> DispatchLambdaArgumentExprs:
    """Bind PythonArgParser outputs to the arguments of the dispatch lambda.

    For every dispatch lambda argument this produces the C++ expression to
    pass for it, together with the C++ statements (``inits``) that must be
    emitted beforehand to set up any locals those expressions refer to.

    Args:
        ps: Python signature whose arguments are parsed by PythonArgParser.
        f: Native function being bound.
        symint: Forwarded to the arg-parser and lambda-argument helpers.

    Returns:
        A ``DispatchLambdaArgumentExprs`` whose ``exprs`` follow the order of
        ``dispatch_lambda_args(ps, f)`` and whose ``inits`` hold the setup
        statements.

    Raises:
        RuntimeError: If the tensor-options arguments of ``ps`` are
            inconsistent with ``f`` (unknown field, wrong type, incomplete
            set, or combined with an output argument where not allowed).
    """
    # This method is to bind 'arg_parser_outputs' and 'lambda_args' by producing
    # 'inits' and 'lambda_args_exprs' for each lambda argument using arg parser
    # outputs.
    arg_parser_outputs = arg_parser_output_exprs(ps, f, symint=symint)
    lambda_args = dispatch_lambda_args(ps, f, symint=symint)
    inits: list[str] = []
    lambda_args_exprs: dict[str, str] = {}

    has_toptions = has_tensor_options(f)

    # 1. special inits/unpacking to provide binding exprs for lambda arguments.
    for a in ps.arguments(skip_tensor_options=True):
        name = a.name
        arg_parser_expr = arg_parser_outputs[a.name].expr

        if has_toptions and name == "self":
            # TODO: why this needs to be special case?
            inits.append(f"auto self = {arg_parser_expr};")
            lambda_args_exprs[name] = name
        elif (
            isinstance(a, PythonOutArgument)
            and len(a.outputs) > 1
            and f.func.is_out_fn()
        ):
            # Packed multi-output argument: parse it once into 'out', then
            # address each individual output by index.
            inits.append(f"auto out = {arg_parser_expr};")
            for i, out_arg in enumerate(a.outputs):
                lambda_args_exprs[out_arg.name] = f"out[{i}]"
        elif str(a.type) == "Dimname[]?":
            # [old codegen]
            # TODO: make this part of something more general, or get rid of it.
            # optional<ArrayRef<T>> are special. The PythonArgParser returns an
            # optional<vector<T>>, which cannot be implicitly converted to
            # optional<ArrayRef<T>>. One needs to unwrap the optional and rewrap.
            inits.extend(
                [
                    f"auto __{name} = {arg_parser_expr};",
                    f"::std::optional<DimnameList> {name} = __{name} ? ::std::make_optional(DimnameList(__{name}.value())) : ::std::nullopt;",  # noqa: B950
                ]
            )
            lambda_args_exprs[name] = name
        else:
            # default case - directly using PythonArgParser output expr
            lambda_args_exprs[name] = arg_parser_expr

    # method's self is passed directly to python binding, rather than parsed
    if ps.method:
        lambda_args_exprs["self"] = "self"

    # 2. special packing/checking for TensorOptions.
    tensor_options_args_names = [a.name for a in ps.tensor_options_args]
    if has_toptions:
        if f.func.is_out_fn():
            raise RuntimeError(f"{f.func}: tensor options with output arg")
        for a in ps.tensor_options_args:
            if a.name not in TENSOR_OPTIONS_FIELDS:
                raise RuntimeError(
                    f"{f.func}: unrecognized tensor options field '{a.name}' in python binding arguments"
                )
            if str(a.type) != TENSOR_OPTIONS_FIELDS.get(a.name):
                raise RuntimeError(
                    f"{f.func}: unrecognized type '{str(a.type)}' for tensor options field '{a.name}'"
                )
        if not all(a in tensor_options_args_names for a in TENSOR_OPTIONS_FIELDS):
            raise RuntimeError(
                f"{f.func}: incomplete tensor options args: {tensor_options_args_names}"
            )

        # Pack the scattered fields into a single 'options' local.
        inits.append(
            f"""\
const auto options = TensorOptions()
.dtype({arg_parser_outputs["dtype"].expr})
.device({arg_parser_outputs["device"].expr})
.layout({arg_parser_outputs["layout"].expr})
.requires_grad({arg_parser_outputs["requires_grad"].expr})
.pinned_memory({arg_parser_outputs["pin_memory"].expr});
torch::utils::maybe_initialize_device(options);
"""
        )
        lambda_args_exprs["options"] = "options"

    # 3. special case - access scattered TensorOptions fields without packing
    # TODO: maybe move to the generator side as it's not related to binding.
    if not has_toptions and tensor_options_args_names:
        if "dtype" in tensor_options_args_names:
            # we're an output-arg variant, check these args against output tensor
            if not f.func.is_out_fn():
                # 'arguments' is a method: call it so the message shows the
                # argument list rather than a bound-method repr.
                raise RuntimeError(
                    f"{f.func}: dtype in tensor_options_args without output arg, {ps} {ps.arguments()}"
                )
            if not all(a in tensor_options_args_names for a in ("layout", "device")):
                raise RuntimeError(
                    f"{f.func}: incomplete tensor options for output check"
                )

            inits.append(
                f"""\
check_out_type_matches({arg_parser_outputs["out"].expr}, {arg_parser_outputs["dtype"].expr},
{arg_parser_outputs["dtype"].is_none_expr}, {arg_parser_outputs["layout"].expr},
{arg_parser_outputs["device"].expr}, {arg_parser_outputs["device"].is_none_expr});
"""
            )
        # we'll set requires_grad on outgoing tensor
        if "requires_grad" not in tensor_options_args_names:
            raise RuntimeError(
                f'{f.func}: expected "requires_grad" in tensor_options_args absent, but found [{tensor_options_args_names}]'
            )

    return DispatchLambdaArgumentExprs(
        exprs=tuple(lambda_args_exprs[a.name] for a in lambda_args),
        inits=inits,
    )
|