| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067 |
- from __future__ import annotations
- import argparse
- import functools
- import json
- import keyword
- import os
- from collections import defaultdict, namedtuple, OrderedDict
- from dataclasses import dataclass, field
- from pathlib import Path
- from typing import Any, Literal, TYPE_CHECKING, TypeVar
- from typing_extensions import assert_never
- import yaml
- import torchgen.api.dispatcher as dispatcher
- import torchgen.api.meta as meta
- import torchgen.api.native as native
- import torchgen.api.structured as structured
- import torchgen.dest as dest
- from torchgen.api import cpp
- from torchgen.api.translate import translate
- from torchgen.api.types import (
- Binding,
- CppSignature,
- CppSignatureGroup,
- DispatcherSignature,
- NamedCType,
- NativeSignature,
- SpecialArgName,
- )
- from torchgen.context import (
- method_with_native_function,
- native_function_manager,
- with_native_function,
- with_native_function_and_indices,
- )
- from torchgen.gen_aoti_c_shim import (
- gen_aoti_c_shim_files,
- gen_static_dispatch_backend_call_signature,
- )
- from torchgen.gen_functionalization_type import (
- gen_functionalization_definition,
- gen_functionalization_registration,
- gen_functionalization_view_inverse_declaration,
- gen_functionalization_view_meta_classes_decl,
- gen_functionalization_view_meta_classes_impl,
- GenCompositeViewCopyKernel,
- )
- from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
- from torchgen.model import (
- Argument,
- BackendIndex,
- BackendMetadata,
- BaseOperatorName,
- DEFAULT_KERNEL_NAMESPACE,
- dispatch_device_map,
- DispatchKey,
- FRAGMENT_NAMESPACES,
- FunctionSchema,
- is_cuda_dispatch_key,
- is_generic_dispatch_key,
- is_ufunc_dispatch_key,
- is_xpu_dispatch_key,
- Location,
- NativeFunction,
- NativeFunctionsGroup,
- NativeFunctionsViewGroup,
- OperatorName,
- OptionalType,
- SchemaKind,
- SelfArgument,
- STRUCTURED_DISPATCH_KEYS,
- TensorOptionsArguments,
- Type,
- Variant,
- ViewSchemaKind,
- )
- from torchgen.native_function_generation import (
- add_generated_native_functions,
- gen_composite_functional_kernel,
- gen_composite_out_kernel,
- pre_group_native_functions,
- )
- from torchgen.selective_build.selector import SelectiveBuilder
- from torchgen.utils import (
- concatMap,
- context,
- FileManager,
- make_file_manager,
- mapMaybe,
- NamespaceHelper,
- Target,
- )
- from torchgen.yaml_utils import YamlDumper, YamlLoader
- if TYPE_CHECKING:
- from collections.abc import Callable, Sequence
- T = TypeVar("T")
- # Welcome to the ATen code generator v2! The ATen code generator is
- # responsible for parsing native_functions.yaml and then generating
- # various generated files (e.g., TypeDefault.cpp) based on the operators
- # defined in this file. This means that the code generator knows how to
- # parse function schema, and then translate this into various C++ types
- # and boilerplate code.
- #
- # Some things to know about this file when you modify it:
- #
- # - This file has STRICT mypy typechecking. Typecheck it with
- # `mypy --config mypy-strict.ini` in the root source directory
- #
- # - Most of the heavy lifting lives in external modules:
- # - 'model' has the data model for native_functions.yaml. The classes
- # in that file represent what you see when you look at
- # native_functions.yaml
- # - 'api' has conversions for how to translate JIT schema into
- # the various C++ APIs that the codegen interacts with. There
- # are in fact THREE different C++ APIs: the public C++ API,
- # the dispatcher API, and the legacy dispatcher API. See each
- # of these respective files for more information
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # HELPER FUNCTIONS
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# A custom loader for YAML to let us also keep track of line numbers
# of each entry in the YAML file
class LineLoader(YamlLoader):
    """YAML loader that annotates every mapping with the line it starts on.

    Each constructed mapping gets an extra ``"__line__"`` key holding a
    1-based line number, which the parsers below use to build Locations.
    """

    def construct_mapping(self, node, deep=False):  # type: ignore[no-untyped-def]
        result = super().construct_mapping(node, deep=deep)  # type: ignore[no-untyped-call]
        # start_mark.line counts from 0; shift it so the first line is line 1.
        result["__line__"] = 1 + node.start_mark.line
        return result
# Result of parsing native_functions.yaml: the sequence of NativeFunctions
# plus the BackendIndex for each dispatch key.
ParsedYaml = namedtuple("ParsedYaml", "native_functions backend_indices")
# Per-process memo tables, keyed by the YAML path that was parsed.
_GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
_GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
def file_manager_from_dispatch_key(
    dispatch_key: DispatchKey,
    device_fms: dict[str, FileManager],
    default_fm: FileManager,
) -> FileManager:
    """Pick the FileManager responsible for `dispatch_key`.

    The device name is the value of the first predicate in
    ``dispatch_device_map`` that accepts `dispatch_key` ("" if none does).
    That name is then looked up in `device_fms`, falling back to
    `default_fm` when it is absent.
    """
    device_name = ""
    for predicate, candidate in dispatch_device_map.items():
        if predicate(dispatch_key):
            device_name = candidate
            break
    return device_fms.get(device_name, default_fm)
def parse_native_yaml_struct(
    es: object,
    valid_tags: set[str],
    ignore_keys: set[DispatchKey] | None = None,
    path: str = "<stdin>",
    skip_native_fns_gen: bool = False,
) -> ParsedYaml:
    """Turn already-loaded native_functions.yaml content into a ParsedYaml.

    Args:
        es: the loaded YAML; must be a list of dicts, each carrying an int
            ``"__line__"`` and a ``"func"`` entry.
        valid_tags: tag names accepted by ``NativeFunction.from_yaml``.
        ignore_keys: dispatch keys forwarded to ``NativeFunction.from_yaml``.
        path: file name used in Locations and error context.
        skip_native_fns_gen: when True, ``add_generated_native_functions``
            is not run.

    Returns:
        ParsedYaml of the NativeFunctions and a defaultdict mapping each
        DispatchKey to its BackendIndex; unknown keys yield an empty
        Undefined-key BackendIndex.

    Raises:
        AssertionError: on malformed entries, or from
        ``error_check_native_functions``.
    """
    if not isinstance(es, list):
        raise AssertionError(f"Expected 'es' to be a list, but got {type(es)}")

    native_functions: list[NativeFunction] = []
    metadata_by_key: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = (
        defaultdict(dict)
    )

    for entry in es:
        if not isinstance(entry, dict):
            raise AssertionError(f"Expected to be dict: {entry}")
        if not isinstance(entry.get("__line__"), int):
            raise AssertionError(f"Expected '__line__' to be int: {entry}")
        loc = Location(path, entry["__line__"])
        funcs = entry.get("func")
        if funcs is None:
            raise AssertionError(f"Missed 'func' in {entry}")
        with context(lambda: f"in {loc}:\n {funcs}"):
            func, metadata = NativeFunction.from_yaml(
                entry, loc, valid_tags, ignore_keys
            )
            native_functions.append(func)
            BackendIndex.grow_index(metadata_by_key, metadata)

    error_check_native_functions(native_functions)

    def undefined_backend_index() -> BackendIndex:
        # Placeholder so codegen does not fail on dispatch keys that have no
        # kernels yet.
        return BackendIndex(
            dispatch_key=DispatchKey.Undefined,
            use_out_as_primary=True,
            external=False,
            device_guard=False,
            # I'm actually not sure about this; undefined could be hit on
            # empty TensorList, hypothetically that could have sizes in it
            index={},
        )

    indices: dict[DispatchKey, BackendIndex] = defaultdict(undefined_backend_index)

    if not skip_native_fns_gen:
        add_generated_native_functions(native_functions, metadata_by_key)

    # All structured in-tree operators are implemented in terms of their out
    # operator; only cuda-like in-tree devices need device guards.
    indices.update(
        {
            key: BackendIndex(
                dispatch_key=key,
                use_out_as_primary=True,
                external=False,
                device_guard=is_cuda_dispatch_key(key) or is_xpu_dispatch_key(key),
                index=kernels,
            )
            for key, kernels in metadata_by_key.items()
        }
    )
    return ParsedYaml(native_functions, indices)
def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
    """Collect the set of tag names declared in already-loaded tags.yaml content.

    Args:
        es: the loaded YAML; expected to be a list of dicts, each carrying an
            int ``"__line__"``, a ``"tag"`` name and a non-empty ``"desc"``.
        path: file name used in Locations and error context.

    Returns:
        The set of declared tag names.

    Raises:
        AssertionError: if `es` is not a list, an entry is not a dict, an
        entry lacks an int ``"__line__"``, an entry lacks ``"tag"``, or a
        tag has an empty/missing description.
    """
    if not isinstance(es, list):
        raise AssertionError(f"Expected 'es' to be a list, but got {type(es)}")
    rs: set[str] = set()
    for e in es:
        # Validate the entry type up front (as parse_native_yaml_struct does)
        # instead of failing with an AttributeError on e.get below.
        if not isinstance(e, dict):
            raise AssertionError(f"Expected to be dict: {e}")
        if not isinstance(e.get("__line__"), int):
            raise AssertionError(f"Expected '__line__' to be int: {e}")
        loc = Location(path, e["__line__"])
        tags = e.get("tag")
        with context(lambda: f"in {loc}:\n {tags}"):
            # A missing/null tag previously surfaced as a bare KeyError (or
            # silently added None to the set).
            if tags is None:
                raise AssertionError(f"Missed 'tag' in {e}")
            desc = e.get("desc", "")
            # ensure that each tag has a non-empty description; a YAML
            # `desc:` with no value loads as None and is rejected too.
            if not desc:
                raise AssertionError(f"Tag '{tags}' must have a non-empty description")
            rs.add(tags)
    return rs
@functools.cache
def parse_tags_yaml(path: str) -> set[str]:
    """Load tags.yaml at `path` and return its validated tag names.

    Results are memoized both by ``functools.cache`` and in
    ``_GLOBAL_PARSE_TAGS_YAML_CACHE`` (keyed by `path`), so the file is read
    at most once per path.
    """
    # The module-level dict is only mutated, never rebound, so no `global`
    # declaration is needed.
    if path in _GLOBAL_PARSE_TAGS_YAML_CACHE:
        return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
    with open(path) as f:
        loaded = yaml.load(f, Loader=LineLoader)
    parsed = parse_tags_yaml_struct(loaded, path=path)
    _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parsed
    return parsed
def parse_native_yaml(
    path: str,
    tags_yaml_path: str,
    ignore_keys: set[DispatchKey] | None = None,
    *,
    skip_native_fns_gen: bool = False,
    loaded_yaml: object | None = None,
) -> ParsedYaml:
    """Parse native_functions.yaml into a ParsedYaml, memoized by `path`.

    Args:
        path: location of native_functions.yaml; also the cache key.
        tags_yaml_path: location of tags.yaml, parsed via ``parse_tags_yaml``.
        ignore_keys: forwarded to ``parse_native_yaml_struct``.
        skip_native_fns_gen: forwarded to ``parse_native_yaml_struct``.
        loaded_yaml: if given, used instead of reading `path` from disk.

    NOTE: the cache key is `path` alone. Once a path has been parsed, later
    calls with the same path return that first result regardless of the
    other arguments.
    """
    cache = _GLOBAL_PARSE_NATIVE_YAML_CACHE
    if path in cache:
        return cache[path]
    valid_tags = parse_tags_yaml(tags_yaml_path)
    if loaded_yaml is not None:
        # The caller already has the YAML in memory; skip the disk read.
        es = loaded_yaml
    else:
        with open(path) as f:
            es = yaml.load(f, Loader=LineLoader)
    cache[path] = parse_native_yaml_struct(
        es,
        valid_tags,
        ignore_keys,
        path=path,
        skip_native_fns_gen=skip_native_fns_gen,
    )
    return cache[path]
# Some assertions are already performed during parsing, but those are only within a single NativeFunction.
# Assertions here are meant to be performed across NativeFunctions.
def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
    """Validate invariants that span multiple NativeFunctions.

    For every function in `funcs`:
      * a ``structured_delegate`` must name an operator present in `funcs`
        that is itself marked ``structured``;
      * no argument may be named after a reserved Python keyword, except for
        a fixed list of pre-existing (operator, argument) pairs;
      * an op tagged ``inplace_view`` (other than resize_, resize_as_ and
        set_) must have an in-place base name (trailing underscore) and a
        matching out-of-place op must exist.

    Raises:
        AssertionError: describing the first violated invariant.
    """
    func_map: dict[OperatorName, NativeFunction] = {}
    base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
    for f in funcs:
        func_map[f.func.name] = f
        base_func_map[f.func.name.name].append(f)

    # Loop-invariant lookup tables: build them once, not once per function.
    python_reserved_keywords = frozenset(keyword.kwlist)
    # Pre-existing operators that are known to have reserved keywords as
    # argument names; the keyword assertion is suppressed for these pairs.
    exclusion_list = frozenset(
        {
            ("_has_compatible_shallow_copy_type", "from"),
            ("random_.from", "from"),
            ("uniform_", "from"),
        }
    )

    for f in funcs:
        if f.structured_delegate is not None:
            delegate_func = func_map.get(f.structured_delegate)
            if delegate_func is None:
                raise AssertionError(
                    f"{f.func.name} is marked as a structured_delegate pointing to "
                    f"{f.structured_delegate}, but {f.structured_delegate} is missing."
                )
            if not delegate_func.structured:
                raise AssertionError(
                    f"{f.func.name} is marked as a structured_delegate pointing to "
                    f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
                    f"Consider adding 'structured=True' to the delegated operator"
                )

        # Check for reserved Python keywords
        for arg in f.func.arguments.flat_all:
            if (
                arg.name in python_reserved_keywords
                and (str(f.func.name), arg.name) not in exclusion_list
            ):
                raise AssertionError(
                    f"Argument name '{arg.name}' in function '{f.func.name}' is a reserved Python keyword."
                )

        # See Note [resize_ in Functionalization]
        # resize_() is technically an inplace view op (and therefore needs the tag),
        # but it would be overkill to add a true "view" variant of resize.
        # Instead, resize_() gets special treatment in functionalization,
        # and we have a resize() op that is non-aliasing + functional.
        if (
            "inplace_view" in f.tags
            and str(f.func.name) != "resize_"
            and str(f.func.name) != "resize_as_"
            and str(f.func.name.name) != "set_"
        ):
            base_name = f.func.name.name
            if not base_name.inplace:
                raise AssertionError(
                    f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
                    "convention for inplace ops - the codegen expects the base name to have a trailing underscore."
                )
            out_of_place_base_name = BaseOperatorName(
                base_name.base, False, base_name.dunder_method
            )
            # .get avoids inserting an empty entry into the defaultdict.
            if not base_func_map.get(out_of_place_base_name):
                # Report the out-of-place name that was looked for, not the
                # in-place name that triggered the check.
                raise AssertionError(
                    f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
                    f"out-of-place view op with the name '{out_of_place_base_name}' and matching schema, but it didn't find one."
                )
# Escape table for cpp_string. Backslash is mapped like every other character;
# str.translate maps each input character exactly once, so inserted
# backslashes are never re-escaped.
_CPP_STRING_ESCAPES = str.maketrans(
    {
        "\\": "\\\\",
        '"': '\\"',
        "\a": "\\a",
        "\b": "\\b",
        "\f": "\\f",
        "\n": "\\n",
        # A raw carriage return inside a C++ string literal is not allowed,
        # so it must be escaped too.
        "\r": "\\r",
        "\v": "\\v",
        "\t": "\\t",
    }
)


def cpp_string(s: str) -> str:
    """Convert a python string into a c++ string literal.

    Backslashes, double quotes and the control characters \\a \\b \\f \\n \\r
    \\v \\t are escaped, and the result is wrapped in double quotes.
    """
    return f'"{s.translate(_CPP_STRING_ESCAPES)}"'
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # C++ CODE GENERATION
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- # Most functions in this section are curried: they consist of a function
- # that takes some parameters (e.g., what is to be generated) which itself
- # returns a function that actually maps NativeFunction to the code
- # to be generated. This pattern makes it convenient to use map, concatMap
- # and similar functional combinators.
def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
    """Return the dispatch keys static dispatch may target.

    With no backends this is empty. Otherwise it is every backend's
    dispatch key followed by the four composite keys.
    """
    if not backends:
        return []
    composite_keys = [
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
    ]
    backend_keys = [backend.dispatch_key for backend in backends]
    return backend_keys + composite_keys
def get_static_dispatch_backend(
    f: NativeFunction, backend_index: BackendIndex
) -> DispatchKey | None:
    """Choose the dispatch key static dispatch should call for `f`.

    `backend_index`'s own key wins if `f` has a structured delegate or a
    kernel in that index. Otherwise the first composite kernel `f` has is
    used, checked in this order: ExplicitAutograd,
    ExplicitAutogradNonFunctional, ImplicitAutograd,
    ImplicitAutogradNestedTensor. Returns None if nothing applies.
    """
    # TODO: for ops with structured_delegate it should check the dispatch table of
    # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
    # so we always dispatch to the `backend`, but this could be wrong when we
    # migrate math/default_backend ops to use structured delegate.
    if f.structured_delegate is not None or backend_index.has_kernel(f):
        return backend_index.dispatch_key
    if f.has_composite_explicit_autograd_kernel:
        return DispatchKey.CompositeExplicitAutograd
    if f.has_composite_explicit_autograd_non_functional_kernel:
        return DispatchKey.CompositeExplicitAutogradNonFunctional
    if f.has_composite_implicit_autograd_kernel:
        return DispatchKey.CompositeImplicitAutograd
    if f.has_composite_implicit_autograd_nested_tensor_kernel:
        return DispatchKey.CompositeImplicitAutogradNestedTensor
    return None
def static_dispatch_ops_header(
    f: NativeFunction, backend_index: list[BackendIndex]
) -> str | None:
    """Return the per-backend ``#include`` lines `f` needs for static dispatch.

    Returns None when `backend_index` is None or `f` uses manual kernel
    registration. Otherwise returns one ``<ATen/ops/<root>_<key>_dispatch.h>``
    include per backend that ``get_static_dispatch_backend`` resolves,
    newline-joined (an empty string if none resolve).
    """
    if backend_index is None or f.manual_kernel_registration:
        return None
    resolved_keys = (get_static_dispatch_backend(f, index) for index in backend_index)
    return "\n".join(
        f"#include <ATen/ops/{f.root_name}_{key.lower()}_dispatch.h>"
        for key in resolved_keys
        if key is not None
    )
def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
    """One ``#include <ATen/<Key>Functions.h>`` per key from ``static_dispatch_keys``."""
    keys = static_dispatch_keys(backends)
    return [f"#include <ATen/{key}Functions.h>" for key in keys]
# Translates arguments of `sig` to CppSignature bindings.
# Note that we have a special case for `memory_format` argument and this case is not covered by
# torchgen.api.translate.translate() yet as its application is limited to static dispatch.
def translate_args(
    sig: CppSignature | DispatcherSignature,
    cpp_sig: CppSignature,
) -> str:
    """Return the comma-joined C++ expressions that call `cpp_sig` from `sig`'s arguments.

    If any binding of `cpp_sig` is named
    ``SpecialArgName.possibly_redundant_memory_format``, each ``memory_format``
    binding of `sig` is first rewritten to carry that same special name, so
    the source and goal NCTypes agree before ``translate`` runs.
    """

    def retag_memory_format(bindings: list[Binding]) -> list[Binding]:
        # Copy `memory_format` bindings with the special NamedCType name;
        # every other binding passes through unchanged.
        return [
            Binding(
                nctype=NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    b.nctype.type,
                ),
                name=b.name,
                default=b.default,
                argument=b.argument,
            )
            if b.name == "memory_format"
            else b
            for b in bindings
        ]

    src_bindings = list(sig.arguments())
    goal_bindings = list(cpp_sig.arguments())
    goal_wants_special_format = any(
        goal.nctype.name == SpecialArgName.possibly_redundant_memory_format
        for goal in goal_bindings
    )
    if goal_wants_special_format:
        src_bindings = retag_memory_format(src_bindings)
    return ", ".join(e.expr for e in translate(src_bindings, goal_bindings))
def generate_static_dispatch_backend_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_index: BackendIndex,
) -> str:
    """Emit ``return <ns>::<key>::<name>(<args>);`` calling `f`'s kernel directly.

    The namespace comes from the kernel's ``cpp_namespace`` when it is set,
    else from ``DEFAULT_KERNEL_NAMESPACE``, with any ``::native`` removed.
    """
    backend_sig = gen_static_dispatch_backend_call_signature(sig, f)
    fn_name = backend_sig.name()
    call_args = translate_args(sig, backend_sig)
    metadata = backend_index.get_kernel(f)
    if metadata and metadata.cpp_namespace:
        kernel_ns = metadata.cpp_namespace
    else:
        kernel_ns = DEFAULT_KERNEL_NAMESPACE
    outer_ns = kernel_ns.replace("::native", "")
    key_ns = backend_index.dispatch_key.lower()
    return f"return {outer_ns}::{key_ns}::{fn_name}({call_args});"
def generate_static_dispatch_fallback_call(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """Emit the C++ used when no static-dispatch backend has a kernel for `f`.

    Looks for a composite kernel on `f`, in this order:
    CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional,
    CompositeImplicitAutograd, CompositeImplicitAutogradNestedTensor. The
    first match yields ``return <ns>::<key>::<name>(<args>);``. If `f` has
    none of them, the result is a ``TORCH_CHECK(false, ...)`` statement that
    names the backends in `backend_indices`.

    Raises:
        AssertionError: if no C++ signature is available for `f`.
    """
    cpp_sigs = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    )
    if sig.symint and f.func.has_symint():
        cpp_sig = cpp_sigs.symint_signature
    else:
        cpp_sig = cpp_sigs.signature
    if cpp_sig is None:
        raise AssertionError("Expected cpp_sig to be non-None")
    name = cpp_sig.name()
    exprs = translate_args(sig, cpp_sig)
    ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")

    if f.has_composite_explicit_autograd_kernel:
        key = DispatchKey.CompositeExplicitAutograd
    elif f.has_composite_explicit_autograd_non_functional_kernel:
        key = DispatchKey.CompositeExplicitAutogradNonFunctional
    elif f.has_composite_implicit_autograd_kernel:
        key = DispatchKey.CompositeImplicitAutograd
    elif f.has_composite_implicit_autograd_nested_tensor_kernel:
        key = DispatchKey.CompositeImplicitAutogradNestedTensor
    else:
        # The previous triple-quoted f-string used a backslash line
        # continuation, which glued "for" onto the backend list ("forCPU")
        # and left a trailing space in the message.
        backends = ", ".join(str(index.dispatch_key) for index in backend_indices)
        return f'TORCH_CHECK(false, "Static dispatch does not support {name} for {backends}");'
    return f"return {ns}::{key.lower()}::{name}({exprs});"
def static_dispatch(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than one
    backends exist, fallback to static dispatch by determining dispatch key from inputs.
    Arguments:
        sig: A CppSignature or DispatcherSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);"
        Returns "" when there are no backends or `f` uses manual kernel registration.
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    # Backends that can serve `f`: those with a kernel for it, plus (for ops
    # with a structured delegate) every structured dispatch key.
    keys = [
        b
        for b in backend_indices
        if b.has_kernel(f)
        or (
            f.structured_delegate is not None
            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if len(keys) == 1:
        return generate_static_dispatch_backend_call(sig, f, keys[0])
    elif len(keys) == 0:
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

    # Several candidate backends: compute the dispatch key at runtime from
    # the tensor arguments and/or tensor options, then switch on it.
    native_tensor_args = [
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or (isinstance(a.argument, Argument) and a.argument.type.is_tensor_like())
    ]
    tensor_args = ", ".join(native_tensor_args)
    tensor_opts = f.func.arguments.tensor_options

    stmts = []
    subexprs: list[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
    stmts.append(f"""DispatchKeySet _dk_set = {" | ".join(subexprs)};""")
    stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")

    dispatch_code = []
    for index in keys:
        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
        # generate_static_dispatch_backend_call already ends its statement
        # with ';' - appending another one produced an empty `;;` statement.
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(sig, f, index)}"""
        )

    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
    connector = "\n\t\t"

    return f"""
    {connector.join(stmts)}
    switch (_dk) {{
        {connector.join(dispatch_code)}
        default:
            {fallback}
    }}
    """
# Generates RegisterSchema.cpp. Depending on the selector, either
# all schemas are registered, or only some are (in the case of
# selective build)
@dataclass(frozen=True)
class RegisterSchema:
    """Render an ``m.def(...)`` schema registration for each selected NativeFunction.

    Each distinct tag set is declared once as
    ``const std::vector<at::Tag> tags_<i>`` and later registrations with the
    same tags reuse ``tags_<i>``. ``known_tags`` maps the rendered tag
    initializer to ``<i>``; the dataclass is frozen but this dict is mutated
    across calls.
    """

    selector: SelectiveBuilder
    known_tags: dict[str, int] = field(default_factory=dict)

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not self.selector.is_native_function_selected(f):
            return None
        ordered_tags = sorted(f.tags)
        schema_literal = cpp_string(str(f.func))
        if not ordered_tags:
            return f"m.def({schema_literal}, {{}});\n"
        tag_init = "{" + ", ".join(f"at::Tag::{tag}" for tag in ordered_tags) + "}"
        first_use = tag_init not in self.known_tags
        # setdefault assigns the next index only when this tag set is new.
        slot = self.known_tags.setdefault(tag_init, len(self.known_tags))
        declaration = (
            f"const std::vector<at::Tag> tags_{slot} = {tag_init};\n"
            if first_use
            else ""
        )
        return f"{declaration}m.def({schema_literal}, tags_{slot});\n"
- # Generates Operators.h and Operators.cpp.
- # These provide macros that, given an operator and overload name, allow users
- # to access an "un-overloaded" function version of the operator. This
- # is useful for extension writers who (1) want to decltype the operator
- # and (2) don't want to worry about method-only operators.
- @dataclass(frozen=True)
- class ComputeOperators:
- target: Literal[Target.DECLARATION, Target.DEFINITION]
- static_dispatch_backend_indices: list[BackendIndex]
- @method_with_native_function
- def __call__(self, f: NativeFunction) -> str:
- sig = DispatcherSignature.from_schema(f.func)
- name = f.func.name.unambiguous_name()
- if self.target is Target.DECLARATION:
- # Note [The ATen Operators API]
- # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
- # metadata about each operator + entry points into the Dispatcher.
- # The C++ function, method, and redispatch API's are all implemented as wrappers
- # into various bits of the structs defined here.
- #
- # Important characteristics about the Operators API:
- # (1) It follows the Dispatcher API.
- # This is kind of necessary to avoid overhead.
- # For example: if it followed the C++ API, then all of the faithful C++ factory functions
- # would need to wrap their arguments into TensorOptions only to unwrap them again.
- # (2) Overload names are disambiguated.
- # This is helpful for pytorch extenders who would like to decltype() an aten operator,
- # that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
- # (3) No argument defaulting is allowed.
- # This is more of an implementation detail to avoid #include cycles,
- # since TensorBody.h (which defines the Tensor class) needs to include this file.
- # (4) manual_cpp_bindings and faithful names are not included in the API.
- # This applies to stuff like __dispatch__is_complex(), and add_outf().
- # These aren't "real aten ops", they're just additional functions provided by the C++ API.
- # They're implemented as wrappers in Functions.h that call into the actual operators
- # defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
- # This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
- return f"""
- struct TORCH_API {name} {{
- using schema = {sig.type()};
- using ptr_schema = schema*;
- // See Note [static constexpr char* members for windows NVCC]
- static constexpr const char* name = "aten::{f.func.name.name}";
- static constexpr const char* overload_name = "{f.func.name.overload_name}";
- static constexpr const char* schema_str = {cpp_string(str(f.func))};
- static {sig.defn(name="call", is_redispatching_fn=False)};
- static {sig.defn(name="redispatch", is_redispatching_fn=True)};
- }};"""
- elif self.target is Target.DEFINITION:
- defns = f"""
- // aten::{f.func}
- static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
- return c10::Dispatcher::singleton()
- .findSchemaOrThrow({name}::name, {name}::overload_name)
- .typed<{name}::schema>();
- }}
- """
- for is_redispatching_fn in [False, True]:
- if is_redispatching_fn:
- dispatcher_exprs_str = ", ".join(
- ["dispatchKeySet"] + [a.name for a in sig.arguments()]
- )
- method_base = "redispatch"
- else:
- dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
- method_base = "call"
- dispatcher_call = method_base
- method_name = f"{name}::{method_base}"
- fn_body = f"""
- static auto op = create_{name}_typed_handle();
- return op.{dispatcher_call}({dispatcher_exprs_str});"""
- if (
- not is_redispatching_fn
- and len(self.static_dispatch_backend_indices) > 0
- ):
- # call() should go through static dispatch
- fn_body = static_dispatch(
- sig, f, backend_indices=self.static_dispatch_backend_indices
- )
- defns += f"""
- // aten::{f.func}
- {sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
- {fn_body}
- }}
- """
- return defns
- else:
- assert_never(self.target)
# Generates Functions.h, which provides the functional public C++ API,
# and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return inline C++ wrappers for ``f`` forwarding to ``at::_ops::<op>::call``.

        One wrapper is emitted per C++ signature when ``f`` has a function
        variant. For schemas containing SymInts, an ``at::symint``-namespaced
        template overload is also emitted per signature.
        """
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        # See Note [The ATen Operators API]
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        op_name = f.func.name.unambiguous_name()
        emit_function = Variant.function in f.variants
        emit_symint_template = f.func.has_symint()

        chunks: list[str] = []
        for cpp_sig in sig_group.signatures():
            call_args = ", ".join(
                e.expr
                for e in translate(cpp_sig.arguments(), dispatcher_sig.arguments())
            )
            if emit_function:
                chunks.append(f"""
// aten::{f.func}
inline {cpp_sig.decl()} {{
return at::_ops::{op_name}::call({call_args});
}}""")
            # The template function can be used from template situations
            # where you want to switch between the symint or not version
            # depending on a template argument
            #
            # NB: we ALWAYS generate this even for methods. But we put it in
            # this header so it can take advantage of per-op headers
            if emit_symint_template:
                int_type = "c10::SymInt" if cpp_sig.symint else "int64_t"
                chunks.append(f"""
namespace symint {{
template <typename T, typename = std::enable_if_t<std::is_same_v<T, {int_type}>>>
{cpp_sig.decl(suppress_symint_suffix=True)} {{
return at::_ops::{op_name}::call({call_args});
}}
}}
""")
        return "".join(chunks)
# Generates TensorBody.h. This file provides the object-oriented (method-based)
# public C++ API, and the scaffolding to call into the dispatcher from these functions.
@dataclass(frozen=True)
class ComputeTensorMethod:
    """Emit ``Tensor`` member declarations (DECLARATION) or their inline
    definitions forwarding to ``at::_ops::<op>::call`` (DEFINITION)."""

    target: Literal[Target.DECLARATION, Target.DEFINITION]
    static_dispatch_backend_indices: list[BackendIndex]

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if Variant.method not in f.variants:
            return None
        # Method variants must be non-out functions that take `self`.
        if f.func.is_out_fn():
            raise AssertionError(f"Method variant cannot be an out function: {f.func}")
        if f.func.arguments.self_arg is None:
            raise AssertionError(f"Method variant must have self_arg: {f.func}")

        signatures = CppSignatureGroup.from_native_function(
            f, method=True, fallback_binding=f.manual_cpp_binding
        ).signatures()

        if self.target is Target.DECLARATION:
            return "".join(f"{sig.decl()} const;\n" for sig in signatures)
        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        op_name = f.func.name.unambiguous_name()
        pieces: list[str] = []
        for sig in signatures:
            call_args = ", ".join(
                e.expr
                for e in translate(
                    sig.arguments(), dispatcher_sig.arguments(), method=True
                )
            )
            pieces.append(f"""
// aten::{f.func}
inline {sig.defn(prefix="Tensor::")} const {{
return at::_ops::{op_name}::call({call_args});
}}
""")
        return "".join(pieces)
# Generates RedispatchFunctions.h.
# This is similar to the C++ API defined in Functions.h, but provides access
# to the dispatcher's redispatch API.
@dataclass(frozen=True)
class ComputeRedispatchFunction:
    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        """Return inline wrappers that take a leading ``dispatchKeySet`` and
        forward to ``at::_ops::<op>::redispatch``, one per C++ signature."""
        # We unconditionally generate function variants of the redispatch API.
        # This is mainly because we can namespace functions separately, but not
        # methods.
        sig_group = CppSignatureGroup.from_native_function(
            f, method=False, fallback_binding=f.manual_cpp_binding
        )
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        op_name = f.func.name.unambiguous_name()
        out = ""
        for sig in sig_group.signatures():
            forwarded = ["dispatchKeySet"]
            forwarded.extend(
                e.expr for e in translate(sig.arguments(), dispatcher_sig.arguments())
            )
            forwarded_str = ", ".join(forwarded)
            out += f"""
// aten::{f.func}
inline {sig.decl(is_redispatching_fn=True)} {{
return at::_ops::{op_name}::redispatch({forwarded_str});
}}
"""
        return out
# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
# TODO: This was historically used to help some JIT interop code
# figure out whether or not to treat aten namespace'd operators
# one way or another, we should reevaluate if this is actually needed.
@with_native_function
def compute_aten_op(f: NativeFunction) -> str:
    """Return one ``{"aten::<name>", "<overload_name>"},`` initializer entry."""
    op_name = f.func.name.name
    overload = f.func.name.overload_name
    return f'{{"aten::{op_name}", "{overload}"}},'
# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
    """Return the C++ ``structured_<name>`` meta struct declaration for ``g``.

    Returns ``None`` when ``g`` is not structured. The struct derives from the
    out-variant's ``structured_inherits`` class (``at::impl::MetaBase`` when
    unset) and declares ``meta(...)`` over the structured meta arguments.

    When the out-variant declares ``precomputed`` elements, a nested
    ``precompute_out`` class template is emitted too:

    * one ``bool`` template parameter per element, true once that element is set;
    * one ``set_<elem>`` method per element, returning a copy with that
      element's flag set to true; a ``static_assert`` rejects setting it twice;
    * ``meta()`` then returns ``precompute_out<true, ..., true>`` through the
      ``meta_return_ty`` alias.
    """
    if not g.structured:
        return None
    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        # g.structured is already known to be true here (early return above),
        # so no further structured check is needed.
        precomputed = g.out.precomputed
        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]
            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
                for elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # The return type is the type of `this` with the template
                # parameter for this element set to true. The static_assert
                # below ensures that parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Each element can be set only once: fail at compile time if
                # this element's template flag is already true.
                assert_msg = f'"{elem.name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Copy every element from `this` except the one being set,
                # which is taken from the method parameter. A distinct loop
                # variable keeps the outer `elem` from being rebound.
                construction_stmts = [f"{return_ty} ret;"]
                for j, other in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{other.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{other.name} = this->{other.name};"
                        )
                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)
                setter_methods.append(
                    f"""
{signature} {{
{assert_stmt}
{construction_block}
}}
"""
                )
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
{precompute_template_decl}
struct TORCH_API precompute_out {{
{setter_methods_decl}
{precomputed_elements_decl};
}};"""
        else:
            meta_return_typedef = ""
            precomputed_decl = ""
        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
{precomputed_decl}
{meta_return_typedef}
{meta_return} meta({args_str});
}};
"""
def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
    """Return whether ``f`` needs a generated BackendSelect kernel.

    ``*_like`` and ``new_*`` operators are excluded. Otherwise ``f`` must take
    TensorOptions and be selected by ``selector``.
    """
    op_name = str(f.func.name.name)
    excluded = op_name.endswith("_like") or op_name.startswith("new_")
    return (
        not excluded
        and f.func.arguments.tensor_options is not None
        and selector.is_native_function_selected(f)
    )
# Generates RegisterBackendSelect.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
@dataclass(frozen=True)
class ComputeBackendSelect:
    target: Literal[Target.DEFINITION, Target.REGISTRATION]

    # Selector object to determine which operators to generate
    # registration code for.
    selector: SelectiveBuilder

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        if not needs_backend_select(f, self.selector):
            return None

        name = native.name(f.func)
        # BackendSelect can go to Meta, so it must preserve symints
        native_sig = NativeSignature(f.func, symint=True)
        tensor_arg_names = [
            binding.name
            for binding in native_sig.arguments()
            if isinstance(binding.argument, Argument)
            and binding.argument.type.is_tensor_like()
        ]
        dispatcher_sig = DispatcherSignature.from_schema(f.func)
        dispatcher_exprs = dispatcher_sig.exprs()
        dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"

        if self.target is Target.REGISTRATION:
            return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
        if self.target is not Target.DEFINITION:
            assert_never(self.target)

        # I don't think there's actually a good reason to generate
        # these two cases differently
        # The first case could probably be improved though- it calls computeDispatchKeySet(),
        # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
        if tensor_arg_names:
            if not f.func.arguments.has_tensor_arg():
                raise AssertionError(
                    f"Expected function to have tensor args: {f.func}"
                )
            tensor_args = ", ".join(tensor_arg_names)
            compute_dk = f"""\
DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
        elif f.func.arguments.has_tensor_arg():
            raise AssertionError(
                f"Expected function to not have tensor args: {f.func}"
            )
        else:
            compute_dk = f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"

        defn = dispatcher_sig.defn(name)
        op_name = f.func.name.unambiguous_name()
        redispatch_args = ", ".join(a.expr for a in dispatcher_exprs)
        return f"""\
// aten::{f.func}
C10_ALWAYS_INLINE
{defn} {{
{compute_dk}
return at::_ops::{op_name}::redispatch(
_dk, {redispatch_args});
}}
"""
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # YAML CODE GENERATION
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def format_yaml(data: object) -> str:
    """Serialize ``data`` to block-style YAML using ``YamlDumper``.

    Note: each call (re)configures the shared ``YamlDumper`` class itself.
    It disables anchors/aliases and registers an ``OrderedDict`` representer.
    """

    def represent_ordered_dict(dumper: Any, data: Any) -> Any:
        # Emit OrderedDict as a plain mapping of its items.
        return dumper.represent_dict(data.items())

    # Never emit YAML anchors/aliases for repeated objects.
    YamlDumper.ignore_aliases = lambda self, data: True  # type: ignore[assignment]
    YamlDumper.add_representer(OrderedDict, represent_ordered_dict)  # type: ignore[no-untyped-call]
    # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
    # A huge width (1e9) effectively disables optional line wrapping,
    # which makes the output more portable.
    return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9)  # type: ignore[no-any-return, call-overload]
# For some reason, some defaults we write to YAML are written as native
# YAML objects, rather than doing them uniformly as strings. This
# function detects those cases and converts them into native Python
# objects.
def pythonify_default(s: str) -> object:
    """Convert a C++ default-expression string into a Python value.

    "true"/"false" become booleans. Strings accepted by ``int`` become ints;
    otherwise strings accepted by ``float`` become floats. Anything else is
    returned unchanged.
    """
    literal_bools = {"true": True, "false": False}
    if s in literal_bools:
        return literal_bools[s]
    for convert in (int, float):
        try:
            return convert(s)
        except ValueError:
            continue
    return s
# What is a dynamic type? Over time, the semantic meaning of
# dynamic type has degraded to meaninglessness (in the old days,
# it captured dtype-ness of types, but that has gone away with
# the removal of TH). These days, it's mostly the same thing as
# the C++ API argument type, except that Tensor and Tensor?
# arguments simply present as Tensor.
#
# TODO: Get rid of dynamic_type, after getting tools/autograd
# to use the new codegen framework
def dynamic_type(t: Type) -> str:
    """Return the legacy "dynamic type" C++ spelling for schema type ``t``."""
    # Optional wrappers are transparent: T? reports the same as T.
    while isinstance(t, OptionalType):
        t = t.elem
    # Note we don't use t.is_tensor_like() here because it would
    # also include Tensor[]
    if str(t) == "Tensor":
        return "at::Tensor"
    # This is a legacy concept, so never report SymInt
    legacy_ctype = cpp.argumenttype_type(
        t, mutable=False, binds="__placeholder__", symint=False
    )
    return legacy_ctype.cpp_type()
def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
    """Return the Declarations.yaml ``method_of`` list for ``variants``.

    Always starts with "Type". "Tensor" follows if ``Variant.method`` is
    present, then "namespace" if ``Variant.function`` is present. The order is
    fixed by the table below, independent of set iteration order.
    """
    ordered_labels = (
        (Variant.method, "Tensor"),
        (Variant.function, "namespace"),
    )
    return ["Type", *(label for variant, label in ordered_labels if variant in variants)]
def compute_returns_yaml(
    f: NativeFunction,
) -> tuple[list[dict[str, str]], dict[str, str]]:
    """Return ``(returns, name_to_field_name)`` for ``f``'s Declarations.yaml entry.

    ``returns`` holds one dict per schema return. ``name_to_field_name`` maps an
    out-argument name to the name of its corresponding return (filled only for
    out functions with named returns).
    """
    # Note [name and field_name]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~
    # To understand name_to_field_name, we must first talk about this
    # schema:
    #
    #   lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
    #
    # There is something very odd about this schema: it is an out
    # variant of the function (that is to say, it will convert into
    # at::lstsq_out() in the C++ API), but the names of the output
    # return arguments don't match the keyword argument names of
    # the inputs.  It TURNS OUT that in this situation, the historical
    # Declarations.yaml we want to output is this (abbreviated to
    # only show relevant fields):
    #
    #   arguments:
    #     ...
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #     ...
    #
    #   returns:
    #   - field_name: solution
    #     name: X
    #   - field_name: QR
    #     name: qr
    #
    # The name of the return fields is stored in 'field_name', and the
    # name of the arguments is stored in 'name'.  So when we process
    # arguments, we need a way to get at the corresponding return.  At
    # the moment, this is most conveniently done by constructing a
    # mapping from name (the argument concept) to field_name (the
    # return concept) while processing return arguments, since we don't
    # directly maintain this correspondence in the modeling of function
    # schema itself.
    #
    # See also https://github.com/pytorch/pytorch/issues/43114
    name_to_field_name: dict[str, str] = {}

    # Compute the returns field of the YAML entry
    return_names = cpp.return_names(f)
    is_out = f.func.is_out_fn()
    returns: list[dict[str, str]] = []
    for idx, (ret, ret_name) in enumerate(zip(f.func.returns, return_names)):
        entry = {
            "dynamic_type": dynamic_type(ret.type),
            "name": ret_name,
            # legacy, report ints
            "type": cpp.return_type(ret, symint=False).cpp_type(),
        }
        if ret.name:
            # See Note [name and field_name]
            entry["field_name"] = ret.name
            if is_out:
                # The idx-th return of an out function pairs with the idx-th out argument.
                name_to_field_name[f.func.arguments.out[idx].name] = ret.name
        returns.append(entry)
    return returns, name_to_field_name
# arguments in yaml roughly corresponds to the public C++ API
def compute_cpp_argument_yaml(
    cpp_a: Binding,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    """Return the Declarations.yaml dict for one public C++ API binding.

    A ``TensorOptionsArguments`` binding is reported as a single kwarg-only
    ``at::TensorOptions`` argument. Plain schema ``Argument`` bindings are
    delegated to ``compute_argument_yaml`` with the same keyword arguments.

    Raises:
        AssertionError: for a ``SelfArgument`` binding (callers in this file
            build bindings from a ``method=False`` signature), or for any
            unexpected binding type instead of silently returning ``None``.
    """
    argument = cpp_a.argument
    if isinstance(argument, TensorOptionsArguments):
        arg: dict[str, object] = {
            "annotation": None,
            "dynamic_type": "at::TensorOptions",
            "is_nullable": False,
            "name": cpp_a.name,
            "type": cpp_a.type,
            "kwarg_only": True,
        }
        if cpp_a.default is not None:
            arg["default"] = cpp_a.default
        return arg
    elif isinstance(argument, SelfArgument):
        raise AssertionError(
            f"compute_cpp_argument_yaml does not support SelfArgument bindings "
            f"(got {cpp_a.name!r})"
        )
    elif isinstance(argument, Argument):
        return compute_argument_yaml(
            argument,
            schema_order=schema_order,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
    else:
        assert_never(argument)
def compute_argument_yaml(
    a: Argument,
    *,
    schema_order: bool,
    kwarg_only_set: set[str],
    out_arg_set: set[str],
    name_to_field_name: dict[str, str],
) -> object:
    """Return the Declarations.yaml dict describing schema argument ``a``.

    ``kwarg_only_set``/``out_arg_set`` flag kwarg-only and output arguments.
    ``name_to_field_name`` supplies the matching return name for out
    arguments (see Note [name and field_name]). ``schema_order`` is accepted
    for call-site parity but is not consulted by this function.
    """
    arg: dict[str, object] = {}
    arg["annotation"] = str(a.annotation) if a.annotation else None
    arg["dynamic_type"] = dynamic_type(a.type)
    arg["is_nullable"] = a.type.is_nullable()
    arg["name"] = a.name
    # legacy, report ints
    arg["type"] = cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type()

    if a.default is not None:
        cpp_default = cpp.default_expr(a.default, a.type, symint=False)
        arg["default"] = pythonify_default(cpp_default)
    if a.name in kwarg_only_set:
        arg["kwarg_only"] = True
    if a.name in out_arg_set:
        arg.update(output=True, allocate=True)
    # See Note [name and field_name]
    if a.name in name_to_field_name:
        arg["field_name"] = name_to_field_name[a.name]

    # Historically, booleans don't get their size recorded, because it
    # is already built into the cpp type (e.g., std::array<bool, 4>)
    list_type = a.type.is_list_like()
    has_recorded_size = (
        list_type is not None
        and list_type.size is not None
        and str(list_type.elem) != "bool"
    )
    if has_recorded_size:
        arg["size"] = list_type.size  # type: ignore[union-attr]
    return arg
@with_native_function
def compute_declaration_yaml(f: NativeFunction) -> object:
    """Build the ordered Declarations.yaml entry for native function ``f``."""
    returns, name_to_field_name = compute_returns_yaml(f)

    # Lookup sets for tagging arguments as kwarg-only or output.
    kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
    out_arg_set = {a.name for a in f.func.arguments.out}

    # Arguments as seen through the public (non-method) C++ API.
    cpp_args = CppSignatureGroup.from_native_function(
        f, method=False, fallback_binding=False
    ).signature.arguments()
    arguments = [
        compute_cpp_argument_yaml(
            binding,
            schema_order=False,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for binding in cpp_args
    ]

    # Arguments in JIT schema order.
    jit_args = list(f.func.schema_order_arguments())
    schema_order_arguments = [
        compute_argument_yaml(
            jit_arg,
            schema_order=True,
            kwarg_only_set=kwarg_only_set,
            out_arg_set=out_arg_set,
            name_to_field_name=name_to_field_name,
        )
        for jit_arg in jit_args
    ]

    cpp_schema_order_types = []
    for jit_arg in jit_args:
        # NB: method here doesn't matter
        bindings = cpp.argument(
            jit_arg,
            method=False,
            cpp_no_default_args=set(),
            faithful=False,
            symint=False,
            has_tensor_options=False,
        )
        cpp_schema_order_types.extend(binding.type for binding in bindings)

    # legacy, report ints
    cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
    schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"

    # A factory method takes TensorOptions and has no Tensor-method variant.
    takes_tensor_options = any(
        isinstance(binding.argument, TensorOptionsArguments) for binding in cpp_args
    )
    is_factory_method = takes_tensor_options and Variant.method not in f.variants

    category_override = "" if f.category_override is None else f.category_override
    python_module = "" if f.python_module is None else f.python_module
    return OrderedDict(
        [
            ("name", cpp.name(f.func)),
            ("operator_name", str(f.func.name.name)),
            ("overload_name", str(f.func.name.overload_name)),
            ("manual_kernel_registration", f.manual_kernel_registration),
            ("category_override", category_override),
            ("schema_string", f"aten::{f.func}"),
            ("arguments", arguments),
            ("schema_order_cpp_signature", schema_order_cpp_signature),
            ("schema_order_arguments", schema_order_arguments),
            ("method_of", compute_method_of_yaml(f.variants)),
            ("mode", "native"),
            ("python_module", python_module),
            ("returns", returns),
            ("inplace", f.func.name.name.inplace),
            ("is_factory_method", is_factory_method),
            ("abstract", f.is_abstract),
            ("device_guard", f.device_guard),
            ("with_gil", False),
            ("deprecated", False),
            ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
        ]
    )
# See Note [Auto generated composite kernels]
def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
    """Return True if ``f`` is structured (or delegates to a structured kernel)
    and is a functional or inplace variant."""
    is_structured_family = f.structured or f.structured_delegate is not None
    if not is_structured_family:
        return False
    return f.func.kind() in (SchemaKind.functional, SchemaKind.inplace)
@with_native_function_and_indices
def compute_registration_declarations(
    f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
) -> str:
    """Return one RegistrationDeclarations.h line for ``f``.

    The line is the dispatcher-API C++ declaration followed by a ``//`` JSON
    comment carrying the schema plus the stringified booleans ``dispatch`` and
    ``default``.
    """
    name = dispatcher.name(f.func)
    returns_type = dispatcher.returns_type(f.func.returns).cpp_type()
    args = dispatcher.arguments(f.func)
    args_str = ", ".join(a.no_default().decl() for a in args)

    # Dispatch keys with a kernel for this operator. Computed once; the
    # previous version rebuilt this set (and re-ran has_kernel) for each
    # comparison.
    kernel_keys = {k for k, v in backend_indices.items() if v.has_kernel(f)}
    # 'dispatch' is False exactly when the only kernels are
    # CompositeImplicitAutograd, optionally together with its NestedTensor variant.
    composite_implicit_only = kernel_keys in (
        {DispatchKey.CompositeImplicitAutograd},
        {
            DispatchKey.CompositeImplicitAutograd,
            DispatchKey.CompositeImplicitAutogradNestedTensor,
        },
    )

    comment_data: dict[str, str] = {
        "schema": f"aten::{f.func}",
        # TODO: What exactly is the semantics of the 'dispatch' field?
        "dispatch": str(not composite_implicit_only),
        "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
    }
    return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
"""
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
- #
- # RUN IT ALL
- #
- # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
def get_custom_build_selector(
    provided_op_registration_allowlist: list[str] | None,
    op_selection_yaml_path: str | None,
) -> SelectiveBuilder:
    """Return the SelectiveBuilder for a (possibly custom) build.

    At most one source may be given: a legacy operator allowlist, or a path to
    an operator-selection YAML file. With neither, a no-op selector is returned.

    Raises:
        AssertionError: if both sources are provided.
    """
    both_given = (
        provided_op_registration_allowlist is not None
        and op_selection_yaml_path is not None
    )
    if both_given:
        raise AssertionError(
            "Both provided_op_registration_allowlist and op_selection_yaml_path "
            "can NOT be provided at the same time."
        )

    if provided_op_registration_allowlist is not None:
        return SelectiveBuilder.from_legacy_op_registration_allow_list(
            set(provided_op_registration_allowlist),
            True,
            False,
        )
    if op_selection_yaml_path is not None:
        return SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
    return SelectiveBuilder.get_nop_selector()
def get_grouped_by_view_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
    """Group view ops with their inplace-view and view_copy counterparts.

    Functions are bucketed by ``f.func.view_signature()``. A bucket containing
    an aliasing view op becomes one ``NativeFunctionsViewGroup``. Its
    aliasing-inplace op and functional (view_copy) op, if any, are folded in.
    Everything else in the bucket is returned individually.

    Raises:
        AssertionError: if two functions share both a view signature and a kind.
    """

    def maybe_create_view_group(
        d: dict[ViewSchemaKind | SchemaKind, NativeFunction],
    ) -> list[NativeFunction | NativeFunctionsViewGroup]:
        result: list[NativeFunction | NativeFunctionsViewGroup] = []
        if ViewSchemaKind.aliasing in d:
            view = d.pop(ViewSchemaKind.aliasing)
            view_copy = d.pop(SchemaKind.functional, None)
            view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
            result.append(
                NativeFunctionsViewGroup(
                    view=view,
                    view_copy=view_copy,
                    view_inplace=view_inplace,
                )
            )
        # Whatever was not absorbed into a view group is emitted on its own.
        result.extend(d.values())
        return result

    grouped_by_views: dict[
        FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
    ] = defaultdict(dict)
    for f in native_functions:
        # One bucket per "view": view op (ViewSchemaKind.aliasing),
        # view_inplace op (ViewSchemaKind.aliasing_inplace) and
        # view_copy op (SchemaKind.functional).
        bucket = grouped_by_views[f.func.view_signature()]
        view_kind: ViewSchemaKind = f.view_schema_kind
        if view_kind != ViewSchemaKind.non_aliasing:
            if view_kind in bucket:
                raise AssertionError(f"{view_kind} already in {bucket.keys()}")
            bucket[view_kind] = f
            continue
        kind = f.func.kind()
        if kind in bucket:
            raise AssertionError(f"Duplicate schema kind {kind} in {bucket.keys()}")
        bucket[kind] = f
    return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
def get_grouped_native_functions(
    native_functions: Sequence[NativeFunction],
) -> Sequence[NativeFunction | NativeFunctionsGroup]:
    """Group related native functions into ``NativeFunctionsGroup`` objects.

    Each pre-group that ``NativeFunctionsGroup.from_dict`` accepts becomes one
    group. Otherwise its functions are returned individually.

    Raises:
        AssertionError: if an ungroupable pre-group contains a function tagged
            "generated".
    """

    def flatten_pre_group(
        d: dict[SchemaKind, NativeFunction],
    ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
        group = NativeFunctionsGroup.from_dict(d)
        if group is not None:
            return [group]
        # Invariant: any NativeFunctions that are code-generated
        # should have been grouped into NativeFunctionsGroup objects
        ungrouped = list(d.values())
        if any("generated" in f.tags for f in ungrouped):
            raise AssertionError(
                "Generated NativeFunctions should have been grouped into "
                f"NativeFunctionsGroup objects: {ungrouped}"
            )
        return ungrouped

    pre_grouped = pre_group_native_functions(native_functions)
    # TODO: how come ValuesView isn't a Sequence lol
    return list(concatMap(flatten_pre_group, list(pre_grouped.values())))
def get_ns_grouped_kernels(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> dict[str, list[str]]:
    """Collect kernel declarations keyed by the C++ namespace they belong in.

    For every function and every backend index, the declarations returned by
    ``native_function_decl_gen(f, backend_idx)`` are appended under that
    backend's kernel namespace. ``DEFAULT_KERNEL_NAMESPACE`` is used when the
    backend has no kernel for ``f``.

    Declaration generation and the namespace check run once per backend, inside
    the inner loop. Doing them once after the loop would only emit the last
    backend's declarations, under the last backend's namespace. It would also
    fail with ``UnboundLocalError`` when ``backend_indices`` is empty.

    Raises:
        AssertionError: if one operator's kernels live in more than one namespace.
    """
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    for f in grouped_native_functions:
        native_function_namespaces: set[str] = set()
        dispatch_keys: set[DispatchKey] = set()
        for dispatch_key, backend_idx in backend_indices.items():
            backend_metadata = backend_idx.get_kernel(f)
            if backend_metadata:
                namespace = backend_metadata.cpp_namespace
                dispatch_keys.add(dispatch_key)
                native_function_namespaces.add(namespace)
            else:
                namespace = DEFAULT_KERNEL_NAMESPACE
            if len(native_function_namespaces) > 1:
                raise AssertionError(
                    f"Codegen only supports one namespace per operator, "
                    f"got {native_function_namespaces} from {dispatch_keys}"
                )
            # Each backend contributes its own declarations for f.
            ns_grouped_kernels[namespace].extend(
                native_function_decl_gen(f, backend_idx)
            )
    return ns_grouped_kernels
def get_native_function_declarations_from_ns_grouped_kernels(
    *,
    ns_grouped_kernels: dict[str, list[str]],
) -> list[str]:
    """Wrap each namespace's kernel declarations in its namespace prologue/epilogue.

    Returns the combined text split into a list of lines.
    """
    newline = "\n"
    declarations: list[str] = []
    for namespace, kernels in ns_grouped_kernels.items():
        ns_helper = NamespaceHelper(
            namespace_str=namespace,
            entity_name="",
            max_level=4,
        )
        # Backends are allowed to repeat kernel names; keep only the first
        # occurrence of each declaration, preserving the original order.
        unique_kernels = list(dict.fromkeys(kernels))
        namespace_block = f"""
{ns_helper.prologue}
{newline.join(unique_kernels)}
{ns_helper.epilogue}
"""
        declarations += namespace_block.split(newline)
    return declarations
# Return native function declarations grouped by their namespaces.
def get_native_function_declarations(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    backend_indices: dict[DispatchKey, BackendIndex],
    native_function_decl_gen: Callable[
        [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
    ] = dest.compute_native_function_declaration,
) -> list[str]:
    """
    Generate kernel declarations, in `NativeFunction(s).h`.
    :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionsGroup`.
    :param backend_indices: kernel collections grouped by dispatch key.
    :param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
    :return: a list of strings: all declarations, grouped by namespace, split by newline.
    """
    return get_native_function_declarations_from_ns_grouped_kernels(
        ns_grouped_kernels=get_ns_grouped_kernels(
            grouped_native_functions=grouped_native_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=native_function_decl_gen,
        )
    )
def get_kernel_namespace(
    *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
) -> str:
    """
    Return the C++ namespace holding the kernel of ``f`` for ``backend_idx``.

    :param f: a native function or a group of native functions.
    :param backend_idx: the backend index the kernel is looked up in.
    :return: the kernel's ``cpp_namespace``, or ``DEFAULT_KERNEL_NAMESPACE``
        when ``backend_idx`` has no kernel metadata for ``f``.
    :raises AssertionError: if the kernel's namespace does not contain
        ``"::native"``.
    """
    metadata = backend_idx.get_kernel(f)
    if not metadata:
        return DEFAULT_KERNEL_NAMESPACE

    kernel_ns = metadata.cpp_namespace
    # NOTE(review): this only requires "::native" to appear somewhere in the
    # namespace, even though the error message below speaks of "ending with".
    if "::native" in kernel_ns:
        return kernel_ns

    func_name = (
        f.func.name if isinstance(f, NativeFunction) else f.functional.func.name
    )
    raise AssertionError(
        f"The kernel for function {func_name} "
        f"with dispatch key {backend_idx.dispatch_key} "
        f"has a namespace {kernel_ns} and it's not ending with '::native'."
    )
# Return native function definitions grouped by dispatch key and custom namespace.
# Used in RegisterDispatchKey.cpp and etc.
def get_native_function_definitions(
    *,
    fm: FileManager,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
    skip_dispatcher_op_registration: bool,
    gen_dispatch_helpers: bool,
) -> list[str]:
    """
    Render ``RegisterDispatchDefinitions.ini`` once per kernel namespace.

    For every function, the namespaced definitions, anonymous-namespace
    definitions and ``TORCH_LIBRARY_IMPL`` registration lines are generated
    and bucketed by kernel namespace (the kernel's ``cpp_namespace`` with
    ``"::native"`` removed). Registrations are further bucketed by the
    operator's own namespace, yielding one ``TORCH_LIBRARY_IMPL`` block per
    operator namespace.

    :param fm: file manager used to substitute the template.
    :param grouped_native_functions: functions / groups to generate for.
    :param dispatch_key: dispatch key used in ``TORCH_LIBRARY_IMPL`` and as
        the dispatch namespace.
    :param backend_idx: backend index the kernels are looked up in.
    :param selector: selective-build selector forwarded to the generators.
    :param rocm: forwarded to the generators.
    :param symint: forwarded to the generators.
    :param skip_dispatcher_op_registration: when true, no registration body is
        emitted into the template.
    :param gen_dispatch_helpers: accepted for interface compatibility; it is
        not read by this function.
    :return: the rendered definitions for all kernel namespaces, split by newline.
    """
    definitions: list[str] = []
    ns_definitions: dict[str, list[str]] = defaultdict(list)
    anonymous_definitions: dict[str, list[str]] = defaultdict(list)
    # kernel namespace -> operator namespace -> registration lines
    registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
    newline = "\n"
    ns_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    anonymous_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.ANONYMOUS_DEFINITION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    reg_gen = dest.RegisterDispatchKey(
        backend_idx,
        Target.REGISTRATION,
        selector,
        rocm=rocm,
        symint=symint,
        class_method_name=None,
        skip_dispatcher_op_registration=skip_dispatcher_op_registration,
    )
    for f in grouped_native_functions:
        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
            "::native", ""
        )

        ns_definitions[kernel_namespace].extend(
            ns_gen(f),
        )
        anonymous_definitions[kernel_namespace].extend(
            anonymous_gen(f),
        )
        namespace = (
            f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
        )
        # Append to this operator namespace's list without touching the lists
        # already collected for other operator namespaces that share the same
        # kernel namespace. Re-assigning registrations[kernel_namespace] here
        # would throw those earlier registrations away.
        registrations[kernel_namespace].setdefault(namespace, []).extend(
            reg_gen(f),
        )

    for kernel_namespace in ns_definitions:
        if len(ns_definitions[kernel_namespace]) == 0:
            continue
        ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
        registration_body = ""
        for namespace in registrations[kernel_namespace]:
            if not registrations[kernel_namespace][namespace]:
                continue
            registration_body += f"""
TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
    {newline.join(registrations[kernel_namespace][namespace])}
}}"""
        # substitute_with_template evaluates the env lambda immediately, so the
        # loop variables captured below are read with their current values.
        definitions.extend(
            fm.substitute_with_template(
                "RegisterDispatchDefinitions.ini",
                lambda: {
                    "ns_prologue": ns_helper.prologue,
                    "ns_epilogue": ns_helper.epilogue,
                    "dispatch_anonymous_definitions": anonymous_definitions[
                        kernel_namespace
                    ],
                    "static_init_dispatch_registrations": ""
                    if skip_dispatcher_op_registration
                    else registration_body,
                    "deferred_dispatch_registrations": "",
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
                },
            ).split(newline)
        )

    return definitions
# Return native function declarations grouped by dispatch key and custom namespace.
# Used in CPUFunctions_inl.h and etc.
def get_namespaced_declaration(
    *,
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    dispatch_key: DispatchKey,
    backend_idx: BackendIndex,
    selector: SelectiveBuilder,
    rocm: bool,
    symint: bool,
) -> list[str]:
    """
    Generate the per-dispatch-key namespaced declarations
    (e.g. the body of ``CPUFunctions_inl.h``).

    Each kernel is declared in ``<kernel namespace without ::native>::<dispatch>``,
    e.g. ``at::native`` -> ``at::cpu``. This matches the namespace that
    ``get_native_function_definitions`` places the definitions in.

    :return: the declarations wrapped in their namespaces, split by newline.
    """
    declarations: list[str] = []
    ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
    newline = "\n"
    dispatch_namespace = dispatch_key.lower()
    func = dest.RegisterDispatchKey(
        backend_idx,
        Target.NAMESPACED_DECLARATION,
        selector,
        rocm=rocm,
        class_method_name=None,
        skip_dispatcher_op_registration=False,
        symint=symint,
    )
    for f in grouped_native_functions:
        kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx)
        # Drop only the "::native" segment, then nest the dispatch namespace,
        # mirroring the definitions side. A bare str.replace("native", ...)
        # would also rewrite "native" inside other namespace components and
        # put declarations in a different namespace from their definitions.
        namespace = (
            kernel_namespace.replace("::native", "") + "::" + dispatch_namespace
        )
        ns_grouped_kernels[namespace].extend(
            func(f),
        )

    for namespace, kernels in ns_grouped_kernels.items():
        if len(kernels) == 0:
            continue
        ns_helper = NamespaceHelper(
            namespace_str=namespace, entity_name="", max_level=3
        )
        # Drop repeated declarations while keeping first-seen order.
        ordered_kernels = list(OrderedDict.fromkeys(kernels))
        declarations.extend(
            f"""
{ns_helper.prologue}
{newline.join(ordered_kernels)}
{ns_helper.epilogue}
        """.split(newline)
        )
    return declarations
# Return native function schema registration code for aten and other namespaces.
def get_native_function_schema_registrations(
    *,
    native_functions: Sequence[NativeFunction],
    schema_selector: SelectiveBuilder,
) -> tuple[list[str], str]:
    """
    Build the schema registration code for every operator namespace.

    :param native_functions: all native functions to register.
    :param schema_selector: selector passed to ``RegisterSchema``.
    :return: ``(aten_registrations, other_registrations)``. The first element
        is the raw list of registration lines for the ``aten`` namespace. The
        second is a string of ``TORCH_LIBRARY`` / ``TORCH_LIBRARY_FRAGMENT``
        blocks, one per non-aten namespace.
    """
    by_namespace: dict[str, list[NativeFunction]] = defaultdict(list)
    for nf in native_functions:
        by_namespace[nf.namespace].append(nf)

    aten_registrations: list[str] = []
    other_registrations = ""
    for ns, ns_funcs in by_namespace.items():
        body = list(mapMaybe(RegisterSchema(schema_selector), ns_funcs))
        if ns == "aten":
            # The template already hard-codes the ATen library block, so aten
            # registrations are returned bare instead of being wrapped here.
            aten_registrations = body
            continue
        # Predefined namespaces must be extended with a library fragment
        # rather than defined as a brand-new library.
        if ns in FRAGMENT_NAMESPACES:
            macro = "TORCH_LIBRARY_FRAGMENT"
        else:
            macro = "TORCH_LIBRARY"
        joined_body = "\t".join(body)
        other_registrations += f"""
{macro}({ns}, m) {{
  {joined_body}
}};"""
    return aten_registrations, other_registrations
def gen_aggregated_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """
    Write the aggregated (one file per category) operator headers.

    Writes ``NativeMetaFunctions.h``, ``MethodOperators.h``, ``Operators.h``,
    ``Functions.h`` and ``NativeFunctions.h`` through ``cpu_fm``. For each
    dispatch key in ``dispatch_keys`` that is also in ``functions_keys``, it
    also writes ``{DispatchKey}Functions.h`` and ``{DispatchKey}Functions_inl.h``.
    """
    # Buck doesn't support dynamic output files, so we aggregate all operator
    # headers into a single file
    cpu_fm.write(
        "NativeMetaFunctions.h",
        lambda: {
            "NativeMetaFunctions_includes": [],
            "NativeMetaFunctions_declarations": list(
                mapMaybe(compute_meta_function_declaration, structured_native_functions)
            ),
        },
    )
    # Split on the method variant directly. Testing membership in
    # method_native_functions instead would scan that list (using dataclass
    # equality) once per function, making this split quadratic.
    method_native_functions = [
        fn for fn in native_functions if Variant.method in fn.variants
    ]
    non_method_native_functions = [
        fn for fn in native_functions if Variant.method not in fn.variants
    ]
    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": [],
            "MethodOperators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Operators.h",
        lambda: {
            "Operators_includes": ["#include <ATen/MethodOperators.h>"],
            "Operators_declarations": list(
                mapMaybe(
                    ComputeOperators(
                        Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    non_method_native_functions,
                )
            ),
        },
    )
    cpu_fm.write(
        "Functions.h",
        lambda: {
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
            "Functions_includes": ["#include <ATen/Operators.h>"],
            "Functions_declarations": list(
                mapMaybe(
                    ComputeFunction(),
                    native_functions,
                )
            ),
        },
    )
    declarations = get_native_function_declarations(
        grouped_native_functions=grouped_native_functions,
        backend_indices=backend_indices,
    )
    cpu_fm.write(
        "NativeFunctions.h",
        lambda: {
            "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
            "NativeFunctions_declarations": declarations,
        },
    )

    for dispatch_key in dispatch_keys:
        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
        if dispatch_key in functions_keys:
            inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

            fm.write_with_template(
                f"{dispatch_key}Functions.h",
                "DispatchKeyFunctions.h",
                lambda: {
                    "dispatch_key": str(dispatch_key),
                    "inline_headers": inl_headers,
                },
            )
            fm.write_with_template(
                f"{dispatch_key}Functions_inl.h",
                "DispatchKeyFunctions_inl.h",
                lambda: {
                    "DispatchKeyFunctions_inl_includes": [],
                    "dispatch_namespace": dispatch_key.lower(),
                    "dispatch_namespaced_declarations": get_namespaced_declaration(
                        grouped_native_functions=grouped_native_functions,
                        dispatch_key=dispatch_key,
                        backend_idx=backend_indices[dispatch_key],
                        selector=selector,
                        rocm=rocm,
                        symint=True,
                    ),
                },
            )

        del fm
def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    ops_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """
    Write one set of headers per operator root name under ``ATen/ops``.

    For every root name this writes ``{name}_ops.h``, ``{name}.h``,
    ``{name}_native.h`` and, for structured operators, ``{name}_meta.h``.
    For each dispatch key in ``functions_keys`` it also writes
    ``{name}_{dispatch}_dispatch.h``. The umbrella headers (``Functions.h``,
    ``Operators.h``, ``NativeMetaFunctions.h``, ``NativeFunctions.h``,
    ``MethodOperators.h``, ``{DispatchKey}Functions(_inl).h``) then just
    include the per-operator files.
    """
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)

    grouped_functions_by_root_name: dict[
        str, list[NativeFunction | NativeFunctionsGroup]
    ] = defaultdict(list)
    for group in grouped_native_functions:
        name = group.root_name
        grouped_functions_by_root_name[name].append(group)

    for name, functions in functions_by_root_name.items():
        ops_fm.write_with_template(
            f"{name}_ops.h",
            "Operator.h",
            lambda: {
                "declarations": list(
                    mapMaybe(
                        ComputeOperators(
                            Target.DECLARATION,
                            static_dispatch_backend_indices=static_dispatch_idx,
                        ),
                        functions,
                    )
                ),
            },
        )

        ops_fm.write_with_template(
            f"{name}.h",
            "Function.h",
            lambda: {
                "static_dispatch_ops_headers": list(
                    mapMaybe(
                        lambda fn: static_dispatch_ops_header(
                            fn, backend_index=static_dispatch_idx
                        ),
                        functions,
                    )
                ),
                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
                "function_definitions": list(
                    mapMaybe(
                        ComputeFunction(),
                        functions,
                    )
                ),
            },
        )

        grouped_functions = grouped_functions_by_root_name.get(name, [])
        structured_functions = [
            fn
            for fn in grouped_functions
            if isinstance(fn, NativeFunctionsGroup) and fn.structured
        ]
        is_structured = len(structured_functions) > 0

        if is_structured:
            ops_fm.write_with_template(
                f"{name}_meta.h",
                "NativeMetaFunction.h",
                lambda: {
                    "meta_function_declarations": list(
                        mapMaybe(
                            compute_meta_function_declaration, structured_functions
                        )
                    ),
                },
            )
        declarations = get_native_function_declarations(
            grouped_native_functions=grouped_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        ops_fm.write_with_template(
            f"{name}_native.h",
            "NativeFunction.h",
            lambda: {
                "extra_includes": (
                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
                ),
                "native_function_declarations": declarations,
            },
        )

    for category, suffix in [
        ("Functions", ""),
        ("Operators", "_ops"),
        ("NativeMetaFunctions", "_meta"),
        ("NativeFunctions", "_native"),
    ]:
        cpu_fm.write(
            f"{category}.h",
            lambda: {
                f"{category}_includes": [
                    f"#include <ATen/ops/{name}{suffix}.h>"
                    for name in sorted(functions_by_root_name.keys())
                ],
                f"{category}_declarations": [],
            },
        )

    for dispatch_key in dispatch_keys:
        if dispatch_key not in functions_keys:
            continue

        dispatch_namespace = dispatch_key.lower()
        dispatch_names: list[str] = []

        # The generator only depends on the dispatch key, so build it once per
        # key instead of once per operator.
        namespaced_declaration_gen = dest.RegisterDispatchKey(
            backend_indices[dispatch_key],
            Target.NAMESPACED_DECLARATION,
            selector,
            rocm=rocm,
            symint=True,
            class_method_name=None,
            skip_dispatcher_op_registration=False,
        )

        for name in functions_by_root_name:
            grouped_functions = grouped_functions_by_root_name.get(name, [])
            declarations = list(
                concatMap(namespaced_declaration_gen, grouped_functions)
            )

            if len(declarations) == 0:
                continue

            dispatch_names.append(name)
            ops_fm.write_with_template(
                f"{name}_{dispatch_namespace}_dispatch.h",
                "DispatchKeyFunction.h",
                lambda: {
                    "dispatch_namespace": dispatch_namespace,
                    "dispatch_namespaced_declarations": declarations,
                },
            )

        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "dispatch_namespace": dispatch_namespace,
                "DispatchKeyFunctions_inl_includes": [
                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
                    for name in sorted(dispatch_names)
                ],
                "dispatch_namespaced_declarations": [],
            },
        )
        del fm

    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": sorted(
                f"#include <ATen/ops/{name}_ops.h>"
                for name, functions in functions_by_root_name.items()
                if any(Variant.method in fn.variants for fn in functions)
            ),
            "MethodOperators_declarations": [],
        },
    )
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: set[str],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
    """
    Generate the ATen headers.

    Operator headers are emitted either one per operator under ``ATen/ops``
    (``per_operator_headers=True``) or aggregated into a handful of large
    headers. In both modes this also writes ``TensorBody.h``,
    ``RedispatchFunctions.h``, ``RegistrationDeclarations.h``,
    ``VmapGeneratedPlumbing.h``, ``aten_interned_strings.h`` and ``enum_tag.h``.
    """
    if not per_operator_headers:
        gen_aggregated_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    else:
        gen_per_operator_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )

    def tensor_body_env() -> dict[str, list[str]]:
        # Declarations and definitions use the same generator, differing
        # only in the target.
        def methods_for(target: Target) -> list[str]:
            compute = ComputeTensorMethod(
                target=target,
                static_dispatch_backend_indices=static_dispatch_idx,
            )
            return list(mapMaybe(compute, native_functions))

        return {
            "tensor_method_declarations": methods_for(Target.DECLARATION),
            "tensor_method_definitions": methods_for(Target.DEFINITION),
        }

    core_fm.write("TensorBody.h", tensor_body_env)

    cpu_fm.write(
        "RedispatchFunctions.h",
        lambda: {
            "function_redispatch_definitions": list(
                mapMaybe(ComputeRedispatchFunction(), native_functions)
            ),
        },
    )

    cpu_fm.write(
        "RegistrationDeclarations.h",
        lambda: {
            "registration_declarations": [
                compute_registration_declarations(nf, backend_indices)
                for nf in native_functions
            ],
        },
    )

    cpu_fm.write(
        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
    )

    def interned_strings_env() -> dict[str, str]:
        attr_names: set[str] = set()  # every argument name in every schema
        op_names: set[str] = set()  # every ATen operator name
        for nf in native_functions:
            op_name = nf.func.name.name
            op_names.add(str(op_name))
            # Some operators don't have a functional variant but we still
            # create a symbol for the base name (without the underscore).
            op_names.add(op_name.base)
            attr_names.update(a.name for a in nf.func.schema_order_arguments())

        # C++ alternative operator tokens are keywords, so they can't be
        # used as symbol names.
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        op_names.difference_update(
            (
                "and",
                "and_eq",
                "bitand",
                "bitor",
                "compl",
                "not",
                "not_eq",
                "or",
                "or_eq",
                "xor",
                "xor_eq",
            )
        )

        aten_lines = [f"_(aten, {n})" for n in sorted(op_names)]
        attr_lines = [f"_(attr, {n})" for n in sorted(attr_names)]
        return {
            "aten_symbols": " \\\n".join(aten_lines),
            "attr_symbols": " \\\n".join(attr_lines),
        }

    core_fm.write("aten_interned_strings.h", interned_strings_env)

    core_fm.write(
        "enum_tag.h",
        lambda: {"enum_of_valid_tags": ",\n".join(sorted(valid_tags))},
    )
def gen_source_files(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    view_groups: Sequence[NativeFunctionsViewGroup],
    selector: SelectiveBuilder,
    static_dispatch_idx: list[BackendIndex],
    backend_indices: dict[DispatchKey, BackendIndex],
    aoti_fm: FileManager,
    core_fm: FileManager,
    cpu_vec_fm: FileManager,
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    force_schema_registration: bool,
    per_operator_headers: bool,
    skip_dispatcher_op_registration: bool,
    update_aoti_c_shim: bool,
    aoti_backends: set[DispatchKey | None],
    extend_aoti_c_shim: bool,
) -> None:
    """
    Generate the ATen C++/CUDA source files.

    For each dispatch key this writes a sharded ``Register{DispatchKey}.cpp``.
    For structured ufunc kernels on CPU/CUDA it also writes ``UfuncCPU_*.cpp``,
    ``UfuncCPUKernel_*.cpp`` and ``UfuncCUDA_*.cu``. After the per-key loop it
    writes the AOTI C shims and these files:
    ``RegisterBackendSelect.cpp``, ``RegisterSchema.cpp``, ``Operators.cpp``
    (sharded), ``Functions.cpp``, ``TensorMethods.cpp``, ``ATenOpList.cpp``,
    ``RegisterFunctionalization.cpp`` (sharded), ``FunctionalInverses.h``,
    ``ViewMetaClasses.h``, ``ViewMetaClasses.cpp`` and
    ``CompositeViewCopyKernels.cpp``.
    """
    extra_cuda_headers = """\
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/ATenCUDAGeneral.h>
#include <ATen/cuda/CUDADevice.h>
#include <ATen/cuda/CUDAContext.h>"""
    if rocm:
        extra_cuda_headers = """\
#include <c10/hip/HIPGuard.h>
#include <ATen/hip/ATenHIPGeneral.h>
#include <ATen/hip/HIPDevice.h>
#include <ATen/hip/HIPContext.h>"""

    for dispatch_key in dispatch_keys:
        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
        # Bound up front because the operator_headers closures below read them.
        backend_index = backend_indices[dispatch_key]
        dispatch_namespace = str(dispatch_key).lower()

        if per_operator_headers:

            def operator_headers() -> list[str]:
                headers = []
                for g in grouped_native_functions:
                    is_registered = False
                    if backend_index.has_kernel(g):
                        is_registered = True
                    # The above has_kernel test on a group will only test for
                    # the existence of out dispatch, because that's how
                    # structured kernels work. But sometimes functions can be
                    # grouped but not be structured, and then you need to check
                    # each individual piece, as they may have manual dispatch
                    # entries.
                    elif isinstance(g, NativeFunctionsGroup) and any(
                        backend_index.has_kernel(fn) for fn in g.functions()
                    ):
                        is_registered = True
                    # TODO: this condition is a bit questionable
                    # (It has to do with the fact that structured kernels get generated kernels
                    # to the Meta + CompositeExplicitAutogradNonFunctional keys).
                    elif g.structured and dispatch_key in (
                        DispatchKey.Meta,
                        DispatchKey.CompositeExplicitAutogradNonFunctional,
                    ):
                        is_registered = True
                    if not is_registered:
                        continue

                    headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
                    if (
                        dispatch_key
                        == DispatchKey.CompositeExplicitAutogradNonFunctional
                    ):
                        headers.append(f"#include <ATen/ops/{g.root_name}.h>")
                    if dispatch_key in functions_keys:
                        headers.append(
                            f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
                        )

                return sorted(set(headers))

        else:

            def operator_headers() -> list[str]:
                headers = ["#include <ATen/NativeFunctions.h>"]
                if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
                    headers.append("#include <ATen/Functions.h>")
                if dispatch_key in functions_keys:
                    headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
                return headers

        # CompositeImplicitAutogradNestedTensor does not currently use the
        # generated helpers; emitting them would fail compilation when the
        # `-Werror=unused-function` flag is set.
        gen_dispatch_helpers: bool = (
            dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
        )

        register_dispatch_key_base_env = {
            "extra_cuda_headers": extra_cuda_headers
            if is_cuda_dispatch_key(dispatch_key)
            else "",
            "external_backend_headers": "",
            "dispatch_headers": dest.gen_registration_headers(
                backend_index, per_operator_headers, rocm
            ),
            # ops_headers *could* be sharded, but doesn't seem necessary?
            "ops_headers": operator_headers(),
            "dispatch_helpers": (
                dest.gen_registration_helpers(backend_index)
                if gen_dispatch_helpers
                else []
            ),
        }

        def register_dispatch_key_env_callable(
            gnf: NativeFunction | NativeFunctionsGroup,
        ) -> dict[str, list[str]]:
            return {
                "dispatch_definitions": get_native_function_definitions(
                    fm=fm,  # noqa: F821
                    grouped_native_functions=[gnf],
                    dispatch_key=dispatch_key,
                    backend_idx=backend_index,
                    selector=selector,
                    rocm=rocm,
                    symint=True,
                    skip_dispatcher_op_registration=skip_dispatcher_op_registration,
                    gen_dispatch_helpers=gen_dispatch_helpers,
                )
            }

        fm.write_sharded_with_template(
            f"Register{dispatch_key}.cpp",
            "RegisterDispatchKey.cpp",
            grouped_native_functions,
            key_fn=lambda x: x.root_name,
            env_callable=register_dispatch_key_env_callable,
            num_shards=4 if dispatch_key == DispatchKey.CPU else 1,
            base_env=register_dispatch_key_base_env,
            sharded_keys={"dispatch_definitions"},
        )

        for g in structured_native_functions:
            if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
                continue
            name = g.functional.func.name.name
            if dispatch_key is DispatchKey.CPU:
                if fm is not cpu_fm:
                    raise AssertionError("Expected fm to be cpu_fm for DispatchKey.CPU")
                fm.write_with_template(
                    f"UfuncCPU_{name}.cpp",
                    "UfuncCPU.cpp",
                    lambda: {
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cpu(g),
                    },
                )
                cpu_vec_fm.write_with_template(
                    f"UfuncCPUKernel_{name}.cpp",
                    "UfuncCPUKernel.cpp",
                    lambda: {
                        "name": name,
                        "native_definitions": dest.compute_ufunc_cpu_kernel(g),
                    },
                )
            elif dispatch_key is DispatchKey.CUDA:
                cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
                if rocm:
                    cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
                fm.write_with_template(
                    f"UfuncCUDA_{name}.cu",
                    "UfuncCUDA.cu",
                    lambda: {
                        "name": name,
                        "cuda_headers": cuda_headers,
                        "meta_declaration": compute_meta_function_declaration(g),
                        "native_declaration": dest.compute_native_function_declaration(
                            g, backend_indices[dispatch_key]
                        ),
                        "native_definitions": dest.compute_ufunc_cuda(g),
                    },
                )
            else:
                raise AssertionError(f"unrecognized {dispatch_key} for ufunc")

        del fm

    gen_aoti_c_shim_files(
        aoti_fm=aoti_fm,
        aoti_backends=aoti_backends,
        native_functions=native_functions,
        backend_indices=backend_indices,
        structured_native_functions=structured_native_functions,
        extra_cuda_headers=extra_cuda_headers,
        update_aoti_c_shim=update_aoti_c_shim,
        extend_aoti_c_shim=extend_aoti_c_shim,
    )

    # BackendSelect is generated specially
    def gen_backend_select() -> dict[str, list[str]]:
        relevant_fns = [
            fn for fn in native_functions if needs_backend_select(fn, selector)
        ]
        return {
            "ops_headers": [
                f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
            ],
            "backend_select_method_definitions": list(
                mapMaybe(
                    ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
                )
            ),
            "backend_select_function_registrations": list(
                mapMaybe(
                    ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
                )
            ),
        }

    cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)

    schema_selector = selector
    if force_schema_registration:
        schema_selector = SelectiveBuilder.get_nop_selector()

    (
        aten_schema_registrations,
        schema_registrations,
    ) = get_native_function_schema_registrations(
        native_functions=native_functions, schema_selector=schema_selector
    )
    cpu_fm.write(
        "RegisterSchema.cpp",
        lambda: {
            "aten_schema_registrations": []
            if skip_dispatcher_op_registration
            else aten_schema_registrations,
            "schema_registrations": []
            if skip_dispatcher_op_registration
            else schema_registrations,
        },
    )

    def key_func(
        fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> str:
        return fn.root_name

    cpu_fm.write_sharded(
        "Operators.cpp",
        native_functions,
        key_fn=key_func,
        env_callable=lambda fn: {
            "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
            "definitions": [
                ComputeOperators(
                    Target.DEFINITION,
                    static_dispatch_backend_indices=static_dispatch_idx,
                )(fn)
            ],
        },
        base_env={
            "static_dispatch_extra_headers": static_dispatch_extra_headers(
                static_dispatch_idx
            ),
        },
        num_shards=5,
        sharded_keys={
            "operator_headers",
            "definitions",
            "static_dispatch_extra_headers",
        },
    )

    cpu_fm.write("Functions.cpp", dict)

    core_fm.write("TensorMethods.cpp", dict)

    core_fm.write(
        "ATenOpList.cpp",
        lambda: {
            "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
        },
    )

    def gen_op_headers(
        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> list[str]:
        if isinstance(g, NativeFunctionsViewGroup):
            # view ops always get a functionalization kernel
            headers = [
                f"#include <ATen/ops/{g.view.root_name}_native.h>",
                f"#include <ATen/ops/{g.view.root_name}_ops.h>",
            ]
            if g.view_copy is not None:
                headers += [
                    f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
                    f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
                ]
            return headers
        elif isinstance(g, NativeFunctionsGroup):
            headers = [
                f"#include <ATen/ops/{g.functional.root_name}_native.h>",
                f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
                f"#include <ATen/ops/{g.out.root_name}_native.h>",
                f"#include <ATen/ops/{g.out.root_name}_ops.h>",
            ]
            if g.inplace is not None:
                headers += [
                    f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
                    f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
                ]
            if g.mutable is not None:
                headers += [
                    f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
                    f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
                ]
            return headers
        else:
            return [
                f"#include <ATen/ops/{g.root_name}_native.h>",
                f"#include <ATen/ops/{g.root_name}_ops.h>",
            ]

    def functionalization_env_callable(
        g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
    ) -> dict[str, list[str]]:
        return {
            "ops_headers": gen_op_headers(g),
            "func_definitions": gen_functionalization_definition(
                selector,
                g,
            ),
            "func_registrations": gen_functionalization_registration(
                selector,
                g,
                backend_indices[DispatchKey.CompositeImplicitAutograd],
            ),
        }

    all_groups: list[
        NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
    ] = list(structured_native_functions) + list(
        view_groups  # type: ignore[assignment, arg-type, operator]
    )
    # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
    # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
    # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
    # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
    #     Although this could go away long-term if we add a dedicated dispatch key for decompositions.
    structured_map: dict[OperatorName, NativeFunction] = {
        f.func.name: f
        for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
    }
    view_map: dict[OperatorName, NativeFunction] = {
        f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
    }
    all_groups.extend(
        f
        for f in native_functions
        if f.func.name not in structured_map and f.func.name not in view_map
    )

    cpu_fm.write_sharded(
        "RegisterFunctionalization.cpp",
        all_groups,
        key_fn=key_func,
        env_callable=functionalization_env_callable,
        num_shards=4,
        sharded_keys={
            "ops_headers",
            "func_definitions",
            "func_registrations",
            "func_add_back_views_definitions",
            "func_add_back_views_registrations",
        },
    )

    cpu_fm.write(
        "FunctionalInverses.h",
        lambda: {
            "view_inverse_declarations": list(
                mapMaybe(
                    lambda g: gen_functionalization_view_inverse_declaration(
                        selector, g
                    ),
                    view_groups,
                )
            )
        },
    )

    cpu_fm.write(
        "ViewMetaClasses.h",
        lambda: {
            "view_meta_declarations": list(
                concatMap(
                    lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
                    view_groups,
                )
            )
        },
    )

    cpu_fm.write(
        "ViewMetaClasses.cpp",
        lambda: {
            "view_meta_implementations": list(
                concatMap(
                    lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
                    view_groups,
                )
            ),
            "op_headers": list(concatMap(gen_op_headers, view_groups)),
        },
    )

    # Note [view_copy NativeFunctions]
    # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
    # needs to have a corresponding non-aliasing {view}_copy variant.
    # Backends that use functionalization and don't know how to handle aliasing ops
    # are expected to implement kernels for these {view}_copy kernels instead.
    # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
    # so we codegen the following:
    # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
    #     These are never explicitly invoked by the functionalization pass,
    #     but they could theoretically be called from user code (I added these kernels for completeness,
    #     since the ops are part of the public API).
    # (2) A derivative formula for every {view}_copy operator
    #     {view}_copy operators can reuse the same derivative formulas as their {view} op counterparts,
    #     so rather than stamping all of the entries out in derivatives.yaml,
    #     we codegen them in.
    #     This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
    cpu_fm.write(
        "CompositeViewCopyKernels.cpp",
        lambda: {
            "ops_headers": [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is important as it ensures we
                    # set the visibility on generated view_copy kernels
                    # correctly
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in (
                        [g.view] if g.view_copy is None else [g.view, g.view_copy]
                    )
                )
                for g in view_groups
            ]
            + [
                "\n".join(
                    f"#include <ATen/ops/{f.root_name}_ops.h>\n"
                    # NB: this include is also important for correct visibility
                    f"#include <ATen/ops/{f.root_name}_native.h>"
                    for f in [g.inplace, g.mutable, g.functional]
                    if f is not None and "generated" not in f.tags
                )
                for g in structured_native_functions
            ],
            "CompositeViewCopyKernel_Definitions": list(
                mapMaybe(
                    GenCompositeViewCopyKernel(
                        backend_indices[
                            DispatchKey.CompositeExplicitAutogradNonFunctional
                        ]
                    ),
                    view_groups,
                )
            ),
            "GeneratedCompositeFunctional_Definitions": list(
                mapMaybe(
                    gen_composite_functional_kernel,
                    structured_native_functions,
                )
            ),
            "GeneratedCompositeOut_Definitions": list(
                mapMaybe(
                    gen_composite_out_kernel,
                    structured_native_functions,
                )
            ),
        },
    )
def gen_declarations_yaml(
    cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
) -> None:
    """Register ``Declarations.yaml`` with ``cpu_fm``.

    The file holds one entry per native function, built by
    ``compute_declaration_yaml`` and serialized by ``format_yaml``.

    The contents are handed to the file manager as a zero-argument callable
    rather than a precomputed value. Rendering therefore happens only if and
    when the file manager invokes it.
    """

    def _render_declarations():
        # Build every per-function entry, then serialize the whole list at once.
        entries = [compute_declaration_yaml(fn) for fn in native_functions]
        return format_yaml(entries)

    cpu_fm.write("Declarations.yaml", _render_declarations)
def get_torchgen_root() -> Path:
    """Return the absolute path of the directory that contains this module.

    Out-of-tree users of torchgen can use this root to locate
    native_functions.yaml.

    The parent directory of ``__file__`` is taken first and only then resolved.
    The module file itself is therefore not dereferenced if it is a symlink.
    """
    module_dir = Path(__file__).parent
    return module_dir.resolve()
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser used by ``main``.

    NB: several help texts are split over adjacent string literals, which
    Python concatenates verbatim. Each fragment therefore ends with a space
    where a word boundary is intended.
    """
    parser = argparse.ArgumentParser(description="Generate ATen source files")
    parser.add_argument(
        "-s",
        "--source-path",
        help="path to source directory for ATen",
        default="aten/src/ATen",
    )
    parser.add_argument(
        "-o",
        "--output-dependencies",
        help="output a list of dependencies into the given file and exit",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="run without writing any files (still updates outputs)",
    )
    parser.add_argument(
        "--per-operator-headers",
        action="store_true",
        help="generate separate headers per operator in ATen/ops",
    )
    parser.add_argument(
        "-d",
        "--install-dir",
        "--install_dir",
        help="output directory",
        default="build/aten/src/ATen",
    )
    parser.add_argument(
        "--aoti-install-dir",
        "--aoti_install_dir",
        help="output directory for AOTInductor shim",
        default="torch/csrc/inductor/aoti_torch/generated",
    )
    parser.add_argument(
        "--rocm",
        action="store_true",
        help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help="Generate MPS registration code when set",
    )
    parser.add_argument(
        "--xpu",
        action="store_true",
        help="Generate XPU registration code when set",
    )
    parser.add_argument(
        "--mtia",
        action="store_true",
        help="Generate MTIA registration code when set",
    )
    # TODO: --op-registration-whitelist will be removed when all call-sites
    # for gen.py are moved over to using the operator YAML file for mobile
    # custom build.
    parser.add_argument(
        "--op-registration-whitelist",
        "--op_registration_whitelist",
        nargs="*",
        help="filter op registrations by the whitelist (if set); "
        "each item is `namespace`::`operator name` without overload name; "
        "e.g.: aten::empty aten::conv2d ...",
    )
    parser.add_argument(
        "--op-selection-yaml-path",
        "--op_selection_yaml_path",
        help="Provide a path to the operator selection (for custom build) YAML "
        "that contains the information about the set of selected operators "
        "and their categories (training, ...). Each operator is either a "
        "full operator name with overload or just a bare operator name. "
        "The operator names also contain the namespace prefix (e.g. aten::)",
    )
    parser.add_argument(
        "--backend-whitelist",
        "--backend_whitelist",
        nargs="*",
        help="filter dispatch backend by the whitelist (if set), "
        "e.g.: CPU CUDA QuantizedCPU ...",
    )
    parser.add_argument(
        "--static-dispatch-backend",
        "--static_dispatch_backend",
        nargs="*",
        help="generate static dispatch code for the specific backend (if set)",
    )
    parser.add_argument(
        "--skip-dispatcher-op-registration",
        "--skip_dispatcher_op_registration",
        action="store_true",
        help="Avoid registering operators into the dispatcher.",
    )
    parser.add_argument(
        "--force-schema-registration",
        "--force_schema_registration",
        action="store_true",
        help="force it to generate schema-only registrations for all ops, including "
        "those that are not listed on --op-registration-whitelist",
    )
    parser.add_argument(
        "--generate",
        type=str,
        nargs="*",
        choices=["headers", "sources", "declarations_yaml"],
        default=["headers", "sources", "declarations_yaml"],
        help="Generate only a subset of files",
    )
    parser.add_argument(
        "--update-aoti-c-shim",
        action="store_true",
        help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
        "WARNING: Do not use this unless you are sure what you are doing!!!",
    )
    parser.add_argument(
        "--extend-aoti-c-shim",
        action="store_true",
        help="This flag indicates the generation of c shims for out-of-tree ATen ops, "
        "which is an extension to the in-tree ATen op c shims. This flag needs to be combined with "
        "--source-path=<out-of-tree native_functions.yaml> "
        "--aoti-install-dir=<in-tree aoti_install_dir>/extend "
        "(default is torch/csrc/inductor/aoti_torch/generated/extend). "
        "WARNING: Do not use this unless you are sure what you are doing!!!",
    )
    return parser


def main() -> None:
    """Command-line entry point for ATen code generation.

    The steps are:

    1. Parse the flags defined by ``_build_arg_parser``. ``--help`` prints
       usage and exits inside ``parse_args``.
    2. Load ``native/native_functions.yaml`` and ``native/tags.yaml`` from
       ``--source-path``.
    3. Decide which dispatch keys to generate code for.
    4. Emit the subsets of output named by ``--generate``: ``sources``,
       ``headers`` and/or ``declarations_yaml``.

    When ``--output-dependencies`` is given, ``write_outputs`` is called on each
    file manager. Each one writes to a prefixed file name in the same directory
    as that path.

    NB: when MPS / XPU / MTIA support is not requested, those keys are removed
    from the module-level ``torchgen.model.dispatch_keys`` list *in place*. Any
    other code holding a reference to that list therefore sees the removal.
    """
    options = _build_arg_parser().parse_args()

    selector = get_custom_build_selector(
        options.op_registration_whitelist,
        options.op_selection_yaml_path,
    )

    native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
    tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")

    from torchgen.model import dispatch_keys

    # Only a limited set of dispatch keys get CPUFunctions.h headers generated
    # for them; this is the set
    functions_keys = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        DispatchKey.CompositeImplicitAutograd,
        DispatchKey.CompositeImplicitAutogradNestedTensor,
        DispatchKey.CompositeExplicitAutograd,
        DispatchKey.CompositeExplicitAutogradNonFunctional,
        DispatchKey.Meta,
        DispatchKey.MTIA,
    }

    aoti_backends = {
        DispatchKey.CPU,
        DispatchKey.CUDA,
        # None will generate the aten shim based on aten_shimified_ops
        # which does not bypass the dispatcher
        None,
    }

    # TODO: stop generating CUDA kernels for non-CUDA builds
    ignore_keys: set[DispatchKey] = set()

    MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS}
    if options.mps or options.update_aoti_c_shim:
        functions_keys.update(MPS_KEYS)
        aoti_backends.add(DispatchKey.MPS)
    else:
        ignore_keys.update(MPS_KEYS)
        # Slice-assignment mutates the shared module-level list in place.
        dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS]

    if options.xpu or options.update_aoti_c_shim:
        functions_keys.add(DispatchKey.XPU)
        aoti_backends.add(DispatchKey.XPU)
    else:
        ignore_keys.add(DispatchKey.XPU)
        if DispatchKey.XPU in dispatch_keys:
            # In-place removal of the first occurrence (module-level list).
            dispatch_keys.remove(DispatchKey.XPU)

    if not options.mtia:
        ignore_keys.add(DispatchKey.MTIA)
        if DispatchKey.MTIA in dispatch_keys:
            dispatch_keys.remove(DispatchKey.MTIA)

    if options.backend_whitelist:
        # Rebinds the local name only; the module-level list is left as-is.
        dispatch_keys = [
            k
            for k in dispatch_keys
            if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
        ]

    parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
    valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
    native_functions = parsed_yaml.native_functions
    backend_indices = parsed_yaml.backend_indices

    grouped_native_functions = get_grouped_native_functions(native_functions)
    structured_native_functions = [
        g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
    ]
    native_functions_with_view_groups = get_grouped_by_view_native_functions(
        native_functions
    )
    view_groups = [
        g
        for g in native_functions_with_view_groups
        if isinstance(g, NativeFunctionsViewGroup)
    ]

    # NB: It is mandatory to NOT use os.path.join here, as the install directory
    # will eventually be ingested by cmake, which does not respect Windows style
    # path slashes. If you switch this to use os.path.join, you'll get an error
    # like:
    #
    #   Syntax error in cmake code when parsing string
    #
    #     C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
    #
    #   Invalid character escape '\c'.
    core_install_dir = f"{options.install_dir}/core"
    ops_install_dir = f"{options.install_dir}/ops"
    aoti_install_dir = options.aoti_install_dir
    for install_dir in (core_install_dir, ops_install_dir, aoti_install_dir):
        Path(install_dir).mkdir(parents=True, exist_ok=True)

    core_fm = make_file_manager(options=options, install_dir=core_install_dir)
    cpu_fm = make_file_manager(options=options)
    cpu_vec_fm = make_file_manager(options=options)
    cuda_fm = make_file_manager(options=options)
    ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
    aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
    device_fms = {"cuda": cuda_fm}
    if options.xpu:
        device_fms["xpu"] = make_file_manager(options=options)

    static_dispatch_idx: list[BackendIndex] = []
    if options.static_dispatch_backend:
        static_dispatch_keys = [
            DispatchKey.parse(key) for key in options.static_dispatch_backend
        ]
        static_dispatch_idx = [backend_indices[k] for k in static_dispatch_keys]
        # Statically dispatched backends join the set of keys that get
        # {Backend}Functions.h headers (see functions_keys above).
        functions_keys.update(static_dispatch_keys)

    if "sources" in options.generate:
        gen_source_files(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            view_groups=view_groups,
            selector=selector,
            static_dispatch_idx=static_dispatch_idx,
            backend_indices=backend_indices,
            aoti_fm=aoti_fm,
            core_fm=core_fm,
            cpu_vec_fm=cpu_vec_fm,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            force_schema_registration=options.force_schema_registration,
            per_operator_headers=options.per_operator_headers,
            skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
            update_aoti_c_shim=options.update_aoti_c_shim,
            aoti_backends=aoti_backends,
            extend_aoti_c_shim=options.extend_aoti_c_shim,
        )

    if "headers" in options.generate:
        gen_headers(
            native_functions=native_functions,
            valid_tags=valid_tags,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            core_fm=core_fm,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=options.rocm,
            per_operator_headers=options.per_operator_headers,
        )

    if "declarations_yaml" in options.generate:
        gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)

    if options.output_dependencies:
        depfile_path = Path(options.output_dependencies).resolve()
        depfile_name = depfile_path.name
        depfile_stem = depfile_path.stem

        # NB: aoti_fm is deliberately not in this list (unchanged behavior).
        fms_with_prefix = [
            (cpu_fm, ""),
            (cpu_vec_fm, "cpu_vec_"),
            (core_fm, "core_"),
            (ops_fm, "ops_"),
        ] + [(device_fm, f"{device}_") for device, device_fm in device_fms.items()]
        for fm, prefix in fms_with_prefix:
            varname = prefix + depfile_stem
            path = depfile_path.parent / (prefix + depfile_name)
            fm.write_outputs(varname, str(path))
# Script entry point: run the generator only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|