gen.py 113 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067
  1. from __future__ import annotations
  2. import argparse
  3. import functools
  4. import json
  5. import keyword
  6. import os
  7. from collections import defaultdict, namedtuple, OrderedDict
  8. from dataclasses import dataclass, field
  9. from pathlib import Path
  10. from typing import Any, Literal, TYPE_CHECKING, TypeVar
  11. from typing_extensions import assert_never
  12. import yaml
  13. import torchgen.api.dispatcher as dispatcher
  14. import torchgen.api.meta as meta
  15. import torchgen.api.native as native
  16. import torchgen.api.structured as structured
  17. import torchgen.dest as dest
  18. from torchgen.api import cpp
  19. from torchgen.api.translate import translate
  20. from torchgen.api.types import (
  21. Binding,
  22. CppSignature,
  23. CppSignatureGroup,
  24. DispatcherSignature,
  25. NamedCType,
  26. NativeSignature,
  27. SpecialArgName,
  28. )
  29. from torchgen.context import (
  30. method_with_native_function,
  31. native_function_manager,
  32. with_native_function,
  33. with_native_function_and_indices,
  34. )
  35. from torchgen.gen_aoti_c_shim import (
  36. gen_aoti_c_shim_files,
  37. gen_static_dispatch_backend_call_signature,
  38. )
  39. from torchgen.gen_functionalization_type import (
  40. gen_functionalization_definition,
  41. gen_functionalization_registration,
  42. gen_functionalization_view_inverse_declaration,
  43. gen_functionalization_view_meta_classes_decl,
  44. gen_functionalization_view_meta_classes_impl,
  45. GenCompositeViewCopyKernel,
  46. )
  47. from torchgen.gen_vmap_plumbing import gen_all_vmap_plumbing
  48. from torchgen.model import (
  49. Argument,
  50. BackendIndex,
  51. BackendMetadata,
  52. BaseOperatorName,
  53. DEFAULT_KERNEL_NAMESPACE,
  54. dispatch_device_map,
  55. DispatchKey,
  56. FRAGMENT_NAMESPACES,
  57. FunctionSchema,
  58. is_cuda_dispatch_key,
  59. is_generic_dispatch_key,
  60. is_ufunc_dispatch_key,
  61. is_xpu_dispatch_key,
  62. Location,
  63. NativeFunction,
  64. NativeFunctionsGroup,
  65. NativeFunctionsViewGroup,
  66. OperatorName,
  67. OptionalType,
  68. SchemaKind,
  69. SelfArgument,
  70. STRUCTURED_DISPATCH_KEYS,
  71. TensorOptionsArguments,
  72. Type,
  73. Variant,
  74. ViewSchemaKind,
  75. )
  76. from torchgen.native_function_generation import (
  77. add_generated_native_functions,
  78. gen_composite_functional_kernel,
  79. gen_composite_out_kernel,
  80. pre_group_native_functions,
  81. )
  82. from torchgen.selective_build.selector import SelectiveBuilder
  83. from torchgen.utils import (
  84. concatMap,
  85. context,
  86. FileManager,
  87. make_file_manager,
  88. mapMaybe,
  89. NamespaceHelper,
  90. Target,
  91. )
  92. from torchgen.yaml_utils import YamlDumper, YamlLoader
  93. if TYPE_CHECKING:
  94. from collections.abc import Callable, Sequence
  95. T = TypeVar("T")
  96. # Welcome to the ATen code generator v2! The ATen code generator is
  97. # responsible for parsing native_functions.yaml and then generating
  98. # various generated files (e.g., TypeDefault.cpp) based on the operators
  99. # defined in this file. This means that the code generator knows how to
  100. # parse function schema, and then translate this into various C++ types
  101. # and boilerplate code.
  102. #
  103. # Some things to know about this file when you modify it:
  104. #
  105. # - This file has STRICT mypy typechecking. Typecheck it with
  106. # `mypy --config mypy-strict.ini` in the root source directory
  107. #
  108. # - Most of the heavy lifting lives in external modules:
  109. # - 'model' has the data model for native_functions.yaml. The classes
  110. # in that file represent what you see when you look at
  111. # a native_functions.yaml
  112. # - 'api' has conversions for how to translate JIT schema into
  113. # the various C++ APIs that the codegen interacts with. There
  114. # are in fact THREE different C++ APIs: the public C++ API,
  115. # the dispatcher API, and the legacy dispatcher API. See each
  116. # of these respective files for more information
  117. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  118. #
  119. # HELPER FUNCTIONS
  120. #
  121. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  122. # A custom loader for YAML to let us also keep track of line numbers
  123. # of each entry in the YAML file
  124. class LineLoader(YamlLoader):
  125. def construct_mapping(self, node, deep=False): # type: ignore[no-untyped-def]
  126. mapping = super().construct_mapping(node, deep=deep) # type: ignore[no-untyped-call]
  127. # Add 1 so line numbering starts at 1
  128. mapping["__line__"] = node.start_mark.line + 1
  129. return mapping
  130. # Parse native_functions.yaml into a sequence of NativeFunctions and Backend Indices.
  131. ParsedYaml = namedtuple("ParsedYaml", ["native_functions", "backend_indices"])
  132. _GLOBAL_PARSE_NATIVE_YAML_CACHE: dict[str, ParsedYaml] = {}
  133. _GLOBAL_PARSE_TAGS_YAML_CACHE: dict[str, set[str]] = {}
  134. def file_manager_from_dispatch_key(
  135. dispatch_key: DispatchKey,
  136. device_fms: dict[str, FileManager],
  137. default_fm: FileManager,
  138. ) -> FileManager:
  139. fm = device_fms.get(
  140. next(
  141. (
  142. device
  143. for check, device in dispatch_device_map.items()
  144. if check(dispatch_key)
  145. ),
  146. "",
  147. ),
  148. default_fm,
  149. )
  150. return fm
  151. def parse_native_yaml_struct(
  152. es: object,
  153. valid_tags: set[str],
  154. ignore_keys: set[DispatchKey] | None = None,
  155. path: str = "<stdin>",
  156. skip_native_fns_gen: bool = False,
  157. ) -> ParsedYaml:
  158. if not isinstance(es, list):
  159. raise AssertionError(f"Expected 'es' to be a list, but got {type(es)}")
  160. rs: list[NativeFunction] = []
  161. bs: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = defaultdict(dict)
  162. for e in es:
  163. if not isinstance(e, dict):
  164. raise AssertionError(f"Expected to be dict: {e}")
  165. if not isinstance(e.get("__line__"), int):
  166. raise AssertionError(f"Expected '__line__' to be int: {e}")
  167. loc = Location(path, e["__line__"])
  168. funcs = e.get("func")
  169. if funcs is None:
  170. raise AssertionError(f"Missed 'func' in {e}")
  171. with context(lambda: f"in {loc}:\n {funcs}"):
  172. func, m = NativeFunction.from_yaml(e, loc, valid_tags, ignore_keys)
  173. rs.append(func)
  174. BackendIndex.grow_index(bs, m)
  175. error_check_native_functions(rs)
  176. # Default dict is to prevent the codegen from barfing when we have a dispatch key that has no kernels yet.
  177. indices: dict[DispatchKey, BackendIndex] = defaultdict(
  178. lambda: BackendIndex(
  179. dispatch_key=DispatchKey.Undefined,
  180. use_out_as_primary=True,
  181. external=False,
  182. device_guard=False,
  183. # I'm actually not sure about this; undefined could be hit on
  184. # empty TensorList, hypothetically that could have sizes in it
  185. index={},
  186. )
  187. )
  188. if not skip_native_fns_gen:
  189. add_generated_native_functions(rs, bs)
  190. for k, v in bs.items():
  191. # All structured in-tree operators are implemented in terms of their out operator.
  192. indices[k] = BackendIndex(
  193. dispatch_key=k,
  194. use_out_as_primary=True,
  195. external=False,
  196. # Only cuda-like devices in tree require device guards
  197. device_guard=is_cuda_dispatch_key(k) or is_xpu_dispatch_key(k),
  198. index=v,
  199. )
  200. return ParsedYaml(rs, indices)
  201. def parse_tags_yaml_struct(es: object, path: str = "<stdin>") -> set[str]:
  202. if not isinstance(es, list):
  203. raise AssertionError(f"Expected 'es' to be a list, but got {type(es)}")
  204. rs: set[str] = set()
  205. for e in es:
  206. if not isinstance(e.get("__line__"), int):
  207. raise AssertionError(f"Expected '__line__' to be int: {e}")
  208. loc = Location(path, e["__line__"])
  209. tags = e.get("tag")
  210. with context(lambda: f"in {loc}:\n {tags}"):
  211. e_i = e.copy()
  212. name = e_i.pop("tag")
  213. desc = e_i.pop("desc", "")
  214. # ensure that each tag has a non-empty description
  215. if desc == "":
  216. raise AssertionError(f"Tag '{name}' must have a non-empty description")
  217. rs.add(name)
  218. return rs
  219. @functools.cache
  220. def parse_tags_yaml(path: str) -> set[str]:
  221. global _GLOBAL_PARSE_TAGS_YAML_CACHE
  222. if path not in _GLOBAL_PARSE_TAGS_YAML_CACHE:
  223. with open(path) as f:
  224. es = yaml.load(f, Loader=LineLoader)
  225. _GLOBAL_PARSE_TAGS_YAML_CACHE[path] = parse_tags_yaml_struct(es, path=path)
  226. return _GLOBAL_PARSE_TAGS_YAML_CACHE[path]
  227. def parse_native_yaml(
  228. path: str,
  229. tags_yaml_path: str,
  230. ignore_keys: set[DispatchKey] | None = None,
  231. *,
  232. skip_native_fns_gen: bool = False,
  233. loaded_yaml: object | None = None,
  234. ) -> ParsedYaml:
  235. global _GLOBAL_PARSE_NATIVE_YAML_CACHE
  236. if path not in _GLOBAL_PARSE_NATIVE_YAML_CACHE:
  237. valid_tags = parse_tags_yaml(tags_yaml_path)
  238. # if a loaded yaml is provided, use that instead of reading from path
  239. if loaded_yaml is None:
  240. with open(path) as f:
  241. es = yaml.load(f, Loader=LineLoader)
  242. else:
  243. es = loaded_yaml
  244. _GLOBAL_PARSE_NATIVE_YAML_CACHE[path] = parse_native_yaml_struct(
  245. es,
  246. valid_tags,
  247. ignore_keys,
  248. path=path,
  249. skip_native_fns_gen=skip_native_fns_gen,
  250. )
  251. return _GLOBAL_PARSE_NATIVE_YAML_CACHE[path]
  252. # Some assertions are already performed during parsing, but those are only within a single NativeFunction.
  253. # Assertions here are meant to be performed across NativeFunctions.
  254. def error_check_native_functions(funcs: Sequence[NativeFunction]) -> None:
  255. func_map: dict[OperatorName, NativeFunction] = {}
  256. base_func_map: dict[BaseOperatorName, list[NativeFunction]] = defaultdict(list)
  257. for f in funcs:
  258. func_map[f.func.name] = f
  259. base_func_map[f.func.name.name].append(f)
  260. for f in funcs:
  261. if f.structured_delegate is not None:
  262. delegate_func = func_map.get(f.structured_delegate)
  263. if delegate_func is None:
  264. raise AssertionError(
  265. f"{f.func.name} is marked as a structured_delegate pointing to "
  266. f"{f.structured_delegate}, but {f.structured_delegate} is missing."
  267. )
  268. if not delegate_func.structured:
  269. raise AssertionError(
  270. f"{f.func.name} is marked as a structured_delegate pointing to "
  271. f"{f.structured_delegate}, but {f.structured_delegate} is not marked as structured. "
  272. f"Consider adding 'structured=True' to the delegated operator"
  273. )
  274. # Check for reserved Python keywords
  275. PYTHON_RESERVED_KEYWORDS = set(keyword.kwlist)
  276. # List of pre-existing operators that are known to have reserved keywords
  277. # Exclusion list is used to suppress the assertion for these operators
  278. EXCLUSION_LIST = {
  279. ("_has_compatible_shallow_copy_type", "from"),
  280. ("random_.from", "from"),
  281. ("uniform_", "from"),
  282. }
  283. for arg in f.func.arguments.flat_all:
  284. if arg.name in PYTHON_RESERVED_KEYWORDS:
  285. if (str(f.func.name), arg.name) not in EXCLUSION_LIST:
  286. raise AssertionError(
  287. f"Argument name '{arg.name}' in function '{f.func.name}' is a reserved Python keyword."
  288. )
  289. # See Note [resize_ in Functionalization]
  290. # resize_() is technically an inplace view op (and therefore needs the tag),
  291. # but it would be overkill to add a true "view" variant of resize.
  292. # Instead, resize_() gets special treatment in functionalization,
  293. # and we have a resize() op that is non-aliasing + functional.
  294. if (
  295. "inplace_view" in f.tags
  296. and str(f.func.name) != "resize_"
  297. and str(f.func.name) != "resize_as_"
  298. and str(f.func.name.name) != "set_"
  299. ):
  300. base_name = f.func.name.name
  301. if not base_name.inplace:
  302. raise AssertionError(
  303. f"{f.func.name} is marked with tag: inplace_view, but it doesn't follow the naming "
  304. "convention for inplace ops - the codegen expects the base name to have a trailing underscore."
  305. )
  306. out_of_place_base_name = BaseOperatorName(
  307. base_name.base, False, base_name.dunder_method
  308. )
  309. if len(base_func_map[out_of_place_base_name]) == 0:
  310. raise AssertionError(
  311. f"{f.func.name} is marked with tag: inplace_view. The codegen expects there to be a corresponding "
  312. f"out-of-place view op with the name '{base_name}' and matching schema, but it didn't find one."
  313. )
  314. def cpp_string(s: str) -> str:
  315. """Convert a python string into a c++ string literal"""
  316. s = s.replace("\\", "\\\\")
  317. s = s.replace('"', '\\"')
  318. s = s.replace("\a", "\\a")
  319. s = s.replace("\b", "\\b")
  320. s = s.replace("\f", "\\f")
  321. s = s.replace("\n", "\\n")
  322. s = s.replace("\v", "\\v")
  323. s = s.replace("\t", "\\t")
  324. return f'"{s}"'
  325. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  326. #
  327. # C++ CODE GENERATION
  328. #
  329. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  330. # Most functions in this section are curried: they consist of a function
  331. # that takes some parameters (e.g., what is to be generated) which itself
  332. # returns a function that actually maps NativeFunction to the code
  333. # to be generated. This pattern makes it convenient to use map, concatMap
  334. # and similar functional combinators.
  335. def static_dispatch_keys(backends: list[BackendIndex]) -> list[DispatchKey]:
  336. if len(backends) == 0:
  337. return []
  338. else:
  339. return [backend.dispatch_key for backend in backends] + [
  340. DispatchKey.CompositeImplicitAutograd,
  341. DispatchKey.CompositeImplicitAutogradNestedTensor,
  342. DispatchKey.CompositeExplicitAutograd,
  343. DispatchKey.CompositeExplicitAutogradNonFunctional,
  344. ]
  345. def get_static_dispatch_backend(
  346. f: NativeFunction, backend_index: BackendIndex
  347. ) -> DispatchKey | None:
  348. if f.structured_delegate is not None or backend_index.has_kernel(f):
  349. # TODO: for ops with structured_delegate it should check the dispatch table of
  350. # the out variant instead. For now, these structured ops all have CPU/CUDA kernels
  351. # so we always dispatch to the `backend`, but this could be wrong when we
  352. # migrate math/default_backend ops to use structured delegate.
  353. return backend_index.dispatch_key
  354. elif f.has_composite_explicit_autograd_kernel:
  355. return DispatchKey.CompositeExplicitAutograd
  356. elif f.has_composite_explicit_autograd_non_functional_kernel:
  357. return DispatchKey.CompositeExplicitAutogradNonFunctional
  358. elif f.has_composite_implicit_autograd_kernel:
  359. return DispatchKey.CompositeImplicitAutograd
  360. elif f.has_composite_implicit_autograd_nested_tensor_kernel:
  361. return DispatchKey.CompositeImplicitAutogradNestedTensor
  362. return None
  363. def static_dispatch_ops_header(
  364. f: NativeFunction, backend_index: list[BackendIndex]
  365. ) -> str | None:
  366. if backend_index is None or f.manual_kernel_registration:
  367. return None
  368. output = []
  369. for index in backend_index:
  370. dispatch_key = get_static_dispatch_backend(f, index)
  371. if dispatch_key is not None:
  372. output.append(
  373. f"#include <ATen/ops/{f.root_name}_{dispatch_key.lower()}_dispatch.h>"
  374. )
  375. return "\n".join(output)
  376. def static_dispatch_extra_headers(backends: list[BackendIndex]) -> list[str]:
  377. return [
  378. f"#include <ATen/{dispatch_key}Functions.h>"
  379. for dispatch_key in static_dispatch_keys(backends)
  380. ]
  381. # Translates arguments of `sig` to CppSignature bindings.
  382. # Note that we have a special case for `memory_format` argument and this case is not covered by
  383. # tools.codegen.api.translate() yet as its application is limited to static dispatch.
  384. def translate_args(
  385. sig: CppSignature | DispatcherSignature,
  386. cpp_sig: CppSignature,
  387. ) -> str:
  388. # Adds SpecialArgName.possibly_redundant_memory_format NamedCType for memory_format bindings
  389. def add_spl_memory_format_binding(input_bindings: list[Binding]) -> list[Binding]:
  390. output_bindings: list[Binding] = []
  391. for binding in input_bindings:
  392. if binding.name == "memory_format":
  393. spl_mem_format_binding = Binding(
  394. nctype=NamedCType(
  395. SpecialArgName.possibly_redundant_memory_format,
  396. binding.nctype.type,
  397. ),
  398. name=binding.name,
  399. default=binding.default,
  400. argument=binding.argument,
  401. )
  402. output_bindings.append(spl_mem_format_binding)
  403. else:
  404. output_bindings.append(binding)
  405. return output_bindings
  406. src_bindings = list(sig.arguments())
  407. goal_bindings = list(cpp_sig.arguments())
  408. # When last argument of CPP signature has SpecialArgName.possibly_redundant_memory_format NCType,
  409. # get memory_format bindings of dispatcher signature to have the same NCType as well
  410. for arg in goal_bindings:
  411. if arg.nctype.name == SpecialArgName.possibly_redundant_memory_format:
  412. src_bindings = add_spl_memory_format_binding(src_bindings)
  413. break
  414. exprs = translate(src_bindings, goal_bindings)
  415. return ", ".join(a.expr for a in exprs)
  416. def generate_static_dispatch_backend_call(
  417. sig: CppSignature | DispatcherSignature,
  418. f: NativeFunction,
  419. backend_index: BackendIndex,
  420. ) -> str:
  421. cpp_sig = gen_static_dispatch_backend_call_signature(sig, f)
  422. name = cpp_sig.name()
  423. exprs = translate_args(sig, cpp_sig)
  424. backend_metadata = backend_index.get_kernel(f)
  425. kernel_ns = (
  426. backend_metadata.cpp_namespace
  427. if backend_metadata and backend_metadata.cpp_namespace
  428. else DEFAULT_KERNEL_NAMESPACE
  429. )
  430. ns = kernel_ns.replace("::native", "")
  431. return f"return {ns}::{backend_index.dispatch_key.lower()}::{name}({exprs});"
  432. def generate_static_dispatch_fallback_call(
  433. sig: CppSignature | DispatcherSignature,
  434. f: NativeFunction,
  435. backend_indices: list[BackendIndex],
  436. ) -> str:
  437. cpp_sigs = CppSignatureGroup.from_native_function(
  438. f, method=False, fallback_binding=False
  439. )
  440. if sig.symint and f.func.has_symint():
  441. cpp_sig = cpp_sigs.symint_signature
  442. else:
  443. cpp_sig = cpp_sigs.signature
  444. if cpp_sig is None:
  445. raise AssertionError("Expected cpp_sig to be non-None")
  446. name = cpp_sig.name()
  447. exprs = translate_args(sig, cpp_sig)
  448. ns = DEFAULT_KERNEL_NAMESPACE.replace("::native", "")
  449. if f.has_composite_explicit_autograd_kernel:
  450. return f"return {ns}::{DispatchKey.CompositeExplicitAutograd.lower()}::{name}({exprs});"
  451. elif f.has_composite_explicit_autograd_non_functional_kernel:
  452. return f"return {ns}::{DispatchKey.CompositeExplicitAutogradNonFunctional.lower()}::{name}({exprs});"
  453. elif f.has_composite_implicit_autograd_kernel:
  454. return f"return {ns}::{DispatchKey.CompositeImplicitAutograd.lower()}::{name}({exprs});"
  455. elif f.has_composite_implicit_autograd_nested_tensor_kernel:
  456. return f"return {ns}::{DispatchKey.CompositeImplicitAutogradNestedTensor.lower()}::{name}({exprs});"
  457. else:
  458. return f"""TORCH_CHECK(false, "Static dispatch does not support {name} for\
  459. {", ".join([str(index.dispatch_key) for index in backend_indices])} ");"""
def static_dispatch(
    sig: CppSignature | DispatcherSignature,
    f: NativeFunction,
    backend_indices: list[BackendIndex],
) -> str:
    """
    For a given `NativeFunction`, find out the corresponding backend and dispatch to it. If more than
    one backend applies, emit a C++ `switch` that picks the backend from the dispatch key computed
    from the inputs at runtime.

    Arguments:
        sig: A CppSignature or DispatcherSignature for this native function we want to use.
        f: NativeFunction to generate static dispatch.
        backend_indices: All available backends.
    Return:
        C++ code to call backend-specific functions, e.g., "return at::cpu::add(self, other, scale);".
        The empty string is returned when `backend_indices` is empty or `f` uses manual kernel
        registration.
    """
    if len(backend_indices) == 0 or f.manual_kernel_registration:
        return ""

    # Backends that can serve `f`: either they have a kernel for it, or `f`
    # has a structured delegate and the backend is a structured dispatch key.
    keys = [
        b
        for b in backend_indices
        if b.has_kernel(f)
        or (
            f.structured_delegate is not None
            and b.dispatch_key in STRUCTURED_DISPATCH_KEYS
        )
    ]
    if len(keys) == 1:
        # Exactly one candidate: call it directly.
        return generate_static_dispatch_backend_call(sig, f, keys[0])
    elif len(keys) == 0:
        # No candidate: fall back to a composite kernel or an error.
        return generate_static_dispatch_fallback_call(sig, f, backend_indices)

    # More than one candidate: collect the arguments that contribute to the
    # dispatch key. Note `and` binds tighter than `or`, so this selects every
    # SelfArgument plus every tensor-like plain Argument.
    native_tensor_args = [
        a.name
        for a in sig.arguments()
        if isinstance(a.argument, SelfArgument)
        or isinstance(a.argument, Argument)
        and a.argument.type.is_tensor_like()
    ]
    tensor_args = ", ".join(native_tensor_args)
    tensor_opts = f.func.arguments.tensor_options

    stmts = []
    subexprs: list[str] = []
    if tensor_opts is not None:
        subexprs.append(
            "DispatchKeySet(c10::computeDispatchKey(dtype, layout, device))"
        )
    if tensor_args != "":
        subexprs.append(f"c10::detail::multi_dispatch_key_set({tensor_args})")
    # The key set is the union of the tensor-options key and the tensor-args keys.
    stmts.append(f"""DispatchKeySet _dk_set = {" | ".join(subexprs)};""")
    stmts.append("DispatchKey _dk = c10::highestPriorityBackendTypeId(_dk_set);")

    # One `case` label followed by the backend call per candidate backend.
    dispatch_code = []
    for index in keys:
        dispatch_code.append(f"""case DispatchKey::{index.dispatch_key}:""")
        dispatch_code.append(
            f"""\t{generate_static_dispatch_backend_call(sig, f, index)};"""
        )

    fallback = generate_static_dispatch_fallback_call(sig, f, backend_indices)
    connector = "\n\t\t"

    # NOTE(review): the whitespace inside this template is emitted verbatim
    # into the generated C++ - confirm it against the canonical file.
    return f"""
    {connector.join(stmts)}
    switch (_dk) {{
        {connector.join(dispatch_code)}
        default:
            {fallback}
    }}
    """
  525. # Generates RegisterSchema.cpp. Depending on the selector, either
  526. # all schemas are registered, or only some are (in the case of
  527. # selective build)
  528. @dataclass(frozen=True)
  529. class RegisterSchema:
  530. selector: SelectiveBuilder
  531. known_tags: dict[str, int] = field(default_factory=dict)
  532. @method_with_native_function
  533. def __call__(self, f: NativeFunction) -> str | None:
  534. if not self.selector.is_native_function_selected(f):
  535. return None
  536. tags = "{" + ", ".join(f"at::Tag::{tag}" for tag in sorted(f.tags)) + "}"
  537. if tags == "{}":
  538. return f"m.def({cpp_string(str(f.func))}, {{}});\n"
  539. maybe_tags = ""
  540. if tags not in self.known_tags:
  541. idx = len(self.known_tags)
  542. self.known_tags[tags] = idx
  543. maybe_tags = f"const std::vector<at::Tag> tags_{idx} = {tags};\n"
  544. return f"{maybe_tags}m.def({cpp_string(str(f.func))}, tags_{self.known_tags[tags]});\n"
  545. # Generates Operators.h and Operators.cpp.
  546. # These provide macros that, given an operator and overload name, allow users
  547. # to access an "un-overloaded" function version of the operator. This
  548. # is useful for extension writers who want to (1) want to decltype the operator
  549. # and (2) don't want to worry about method-only operators.
  550. @dataclass(frozen=True)
  551. class ComputeOperators:
  552. target: Literal[Target.DECLARATION, Target.DEFINITION]
  553. static_dispatch_backend_indices: list[BackendIndex]
  554. @method_with_native_function
  555. def __call__(self, f: NativeFunction) -> str:
  556. sig = DispatcherSignature.from_schema(f.func)
  557. name = f.func.name.unambiguous_name()
  558. if self.target is Target.DECLARATION:
  559. # Note [The ATen Operators API]
  560. # The ATen Operators API lives in the at::_ops namespace, and contains compile-time
  561. # metadata about each operator + entry points into the Dispatcher.
  562. # The C++ function, method, and redispatch API's are all implemented as wrappers
  563. # into various bits of the structs defined here.
  564. #
  565. # Important characteristics about the Operators API:
  566. # (1) It follows the Dispatcher API.
  567. # This is kind of necessary to avoid overhead.
  568. # For example: if it followed the C++ API, then all of the faithful C++ factory functions
  569. # would need to wrap their arguments into TensorOptions only to unwrap them again.
  570. # (2) Overload names are disambiguated.
  571. # This is helpful for pytorch extenders who would like to decltype() an aten operator,
  572. # that has overloads, e.g. decltype(at::_ops::mul_Tensor::call)
  573. # (3) No argument defaulting is allowed.
  574. # This is more of an implementation detail to avoid #include cycles,
  575. # since TensorBody.h (which defines the Tensor class) needs to include this file.
  576. # (4) manual_cpp_bindings and faithful names are not included in the API.
  577. # This applies to stuff like __dispatch__is_complex(), and add_outf().
  578. # These aren't "real aten ops", they're just additional functions provided by the C++ API.
  579. # They're implemented as wrappers in Functions.h that call into the actual operators
  580. # defined here, i.e. at::_ops::is_complex::call() and at::_ops::add_out::call().
  581. # This means that ATEN_OP(is_complex) will not fastpath, and will go through the dispatcher.
  582. return f"""
  583. struct TORCH_API {name} {{
  584. using schema = {sig.type()};
  585. using ptr_schema = schema*;
  586. // See Note [static constexpr char* members for windows NVCC]
  587. static constexpr const char* name = "aten::{f.func.name.name}";
  588. static constexpr const char* overload_name = "{f.func.name.overload_name}";
  589. static constexpr const char* schema_str = {cpp_string(str(f.func))};
  590. static {sig.defn(name="call", is_redispatching_fn=False)};
  591. static {sig.defn(name="redispatch", is_redispatching_fn=True)};
  592. }};"""
  593. elif self.target is Target.DEFINITION:
  594. defns = f"""
  595. // aten::{f.func}
  596. static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{
  597. return c10::Dispatcher::singleton()
  598. .findSchemaOrThrow({name}::name, {name}::overload_name)
  599. .typed<{name}::schema>();
  600. }}
  601. """
  602. for is_redispatching_fn in [False, True]:
  603. if is_redispatching_fn:
  604. dispatcher_exprs_str = ", ".join(
  605. ["dispatchKeySet"] + [a.name for a in sig.arguments()]
  606. )
  607. method_base = "redispatch"
  608. else:
  609. dispatcher_exprs_str = ", ".join([a.name for a in sig.arguments()])
  610. method_base = "call"
  611. dispatcher_call = method_base
  612. method_name = f"{name}::{method_base}"
  613. fn_body = f"""
  614. static auto op = create_{name}_typed_handle();
  615. return op.{dispatcher_call}({dispatcher_exprs_str});"""
  616. if (
  617. not is_redispatching_fn
  618. and len(self.static_dispatch_backend_indices) > 0
  619. ):
  620. # call() should go through static dispatch
  621. fn_body = static_dispatch(
  622. sig, f, backend_indices=self.static_dispatch_backend_indices
  623. )
  624. defns += f"""
  625. // aten::{f.func}
  626. {sig.defn(name=method_name, is_redispatching_fn=is_redispatching_fn)} {{
  627. {fn_body}
  628. }}
  629. """
  630. return defns
  631. else:
  632. assert_never(self.target)
  633. # Generates Functions.h, which provides the functional public C++ API,
  634. # and the scaffolding to call into the dispatcher from these functions.
  635. @dataclass(frozen=True)
  636. class ComputeFunction:
  637. @method_with_native_function
  638. def __call__(self, f: NativeFunction) -> str | None:
  639. sig_group = CppSignatureGroup.from_native_function(
  640. f, method=False, fallback_binding=f.manual_cpp_binding
  641. )
  642. has_symint = f.func.has_symint()
  643. result = ""
  644. for sig in sig_group.signatures():
  645. # See Note [The ATen Operators API]
  646. target_sig = DispatcherSignature.from_schema(f.func)
  647. exprs = translate(sig.arguments(), target_sig.arguments())
  648. exprs_str = ", ".join([e.expr for e in exprs])
  649. if sig.symint:
  650. intlike_t = "c10::SymInt"
  651. else:
  652. intlike_t = "int64_t"
  653. if Variant.function in f.variants:
  654. result += f"""
  655. // aten::{f.func}
  656. inline {sig.decl()} {{
  657. return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
  658. }}"""
  659. # The template function can be used from template situations
  660. # where you want to switch between the symint or not version
  661. # depending on a template argument
  662. #
  663. # NB: we ALWAYS generate this even for methods. But we put it in
  664. # this header so it can take advantage of per-op headers
  665. if has_symint:
  666. result += f"""
  667. namespace symint {{
  668. template <typename T, typename = std::enable_if_t<std::is_same_v<T, {intlike_t}>>>
  669. {sig.decl(suppress_symint_suffix=True)} {{
  670. return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
  671. }}
  672. }}
  673. """
  674. return result
  675. # Generates TensorBody.h. This file provides the object-oriented (method-based)
  676. # public C++ API, and the scaffolding to call into the dispatcher from these functions.
  677. @dataclass(frozen=True)
  678. class ComputeTensorMethod:
  679. target: Literal[Target.DECLARATION, Target.DEFINITION]
  680. static_dispatch_backend_indices: list[BackendIndex]
  681. @method_with_native_function
  682. def __call__(self, f: NativeFunction) -> str | None:
  683. if Variant.method not in f.variants:
  684. return None
  685. if f.func.is_out_fn():
  686. raise AssertionError(f"Method variant cannot be an out function: {f.func}")
  687. if f.func.arguments.self_arg is None:
  688. raise AssertionError(f"Method variant must have self_arg: {f.func}")
  689. sig_group = CppSignatureGroup.from_native_function(
  690. f, method=True, fallback_binding=f.manual_cpp_binding
  691. )
  692. if self.target is Target.DECLARATION:
  693. result = ""
  694. for sig in sig_group.signatures():
  695. result += f"{sig.decl()} const;\n"
  696. return result
  697. if self.target is not Target.DEFINITION:
  698. assert_never(self.target)
  699. result = ""
  700. for sig in sig_group.signatures():
  701. target_sig = DispatcherSignature.from_schema(f.func)
  702. exprs = translate(sig.arguments(), target_sig.arguments(), method=True)
  703. exprs_str = ", ".join([e.expr for e in exprs])
  704. result += f"""
  705. // aten::{f.func}
  706. inline {sig.defn(prefix="Tensor::")} const {{
  707. return at::_ops::{f.func.name.unambiguous_name()}::call({exprs_str});
  708. }}
  709. """
  710. return result
  711. # Generates RedispatchFunctions.h.
  712. # This is similar to the C++ API defined in Functions.h, but provides access
  713. # to the dispatcher's redispatch API.
  714. @dataclass(frozen=True)
  715. class ComputeRedispatchFunction:
  716. @method_with_native_function
  717. def __call__(self, f: NativeFunction) -> str | None:
  718. # We unconditionally generate function variants of the redispatch API.
  719. # This is mainly because we can namespace functions separately, but not methods,
  720. sig_group = CppSignatureGroup.from_native_function(
  721. f, method=False, fallback_binding=f.manual_cpp_binding
  722. )
  723. result = ""
  724. for sig in sig_group.signatures():
  725. target_sig = DispatcherSignature.from_schema(f.func)
  726. exprs = translate(sig.arguments(), target_sig.arguments())
  727. exprs_str = ", ".join(["dispatchKeySet"] + [a.expr for a in exprs])
  728. result += f"""
  729. // aten::{f.func}
  730. inline {sig.decl(is_redispatching_fn=True)} {{
  731. return at::_ops::{f.func.name.unambiguous_name()}::redispatch({exprs_str});
  732. }}
  733. """
  734. return result
  735. # Generates ATenOpList.cpp, a runtime accessible list of all aten
  736. # operators.
  737. # TODO: This was historically used to help some JIT interop code
  738. # figure out whether or not to treat aten namespace'd operators
  739. # one way or another, we should reevaluate if this is actually needed.
  740. @with_native_function
  741. def compute_aten_op(f: NativeFunction) -> str:
  742. return f'{{"aten::{f.func.name.name}", "{f.func.name.overload_name}"}},'
# Generates MetaFunctions.h
def compute_meta_function_declaration(g: NativeFunctionsGroup) -> str | None:
    """
    Return the C++ declaration of the `structured_<name>` meta struct for `g`.

    Returns None when the group is not structured. The struct derives from
    `g.out.structured_inherits`, or from `at::impl::MetaBase` when that is None,
    and declares a `meta(...)` member. When the out function has precomputed
    elements, a nested `precompute_out` template struct with one setter per
    element is also emitted and `meta` returns `meta_return_ty` instead of `void`.
    """
    if not g.structured:
        return None

    with native_function_manager(g.out):
        name = meta.name(g)
        args = structured.meta_arguments(g)
        args_str = ", ".join(a.decl() for a in args)
        parent_class = g.out.structured_inherits
        if parent_class is None:
            parent_class = "at::impl::MetaBase"
        meta_return = "void"
        # NOTE: `g.structured` is already known to be truthy here because of the
        # early return above.
        precomputed = g.out.precomputed if g.structured else None

        if precomputed:
            # Generate the template declaration with one bool parameter for each
            # precomputed element. Each parameter is true if the corresponding (in
            # terms of position) precomputed element has been set.
            precomputed_values = [*precomputed.replace.values(), precomputed.add]
            precomputed_elements = [
                elem for replace_list in precomputed_values for elem in replace_list
            ]
            precomputed_template_parameters = [
                elem.name.upper() for elem in precomputed_elements
            ]
            precomputed_template_params_str = ", ".join(
                f"bool {param} = false" for param in precomputed_template_parameters
            )
            precompute_template_decl = f"template <{precomputed_template_params_str}>"

            # Generate a string containing declarations of all precomputed elements.
            precomputed_elements_with_cpp_types = [
                structured.argument_type(elem, binds=elem.name)
                for elem in precomputed_elements
            ]

            precomputed_elements_decl = ";\n".join(
                f"{elem.cpp_type(strip_ref=True)} {elem.name}"
                for elem in precomputed_elements_with_cpp_types
            )

            # Generate "setter" methods for each precomputed element. Each method will return
            # a new instance of precompute_out with the template parameter that corresponds to
            # the member set by the method to true (to indicate that it has been set).
            setter_methods = []
            for i, elem in enumerate(precomputed_elements):
                # Generate the signature. The return type will be the same
                # as the type of `this` but with the template parameter
                # corresponding to the element set by this method set to true.
                # The assert generated below will ensure that this template
                # parameter is false on the type of `this`.
                return_ty_templates = ", ".join(
                    precomputed_template_parameters[:i]
                    + ["true"]
                    + precomputed_template_parameters[i + 1 :]
                )
                return_ty = f"precompute_out<{return_ty_templates}>"
                elem_cpp_ty = precomputed_elements_with_cpp_types[i].cpp_type(
                    strip_ref=True
                )
                signature = f"{return_ty} set_{elem.name}({elem_cpp_ty} value)"

                # Generate an assert which checks that the
                # template parameter corresponding to the precomputed
                # element that is set by this method is false on the
                # class corresponding to the object that `this` points to.
                # This ensures that each element can be set only once.
                assert_msg = f'"{elem.name} already set"'
                assert_stmt = f"static_assert({precomputed_template_parameters[i]} == false, {assert_msg});"

                # Generate the new object construction block. All state
                # except the element that this method sets is copied from the
                # object that `this` points to. The value for the element that
                # the method sets is taken from a method parameter.
                construction_stmts = []
                construction_stmts.append(f"{return_ty} ret;")

                # NOTE: this inner loop rebinds `elem`. Every use of the outer
                # loop's `elem` for setter `i` happens above this point.
                for j, elem in enumerate(precomputed_elements):
                    if i == j:
                        construction_stmts.append(f"ret.{elem.name} = value;")
                    else:
                        construction_stmts.append(
                            f"ret.{elem.name} = this->{elem.name};"
                        )

                construction_stmts.append("return ret;")
                construction_block = "\n".join(construction_stmts)

                setter_methods.append(
                    f"""
                    {signature} {{
                        {assert_stmt}
                        {construction_block}
                    }}
                """
                )
            setter_methods_decl = "\n".join(setter_methods)

            # Meta should return an instance of the struct containing the precomputed elements.
            meta_return_template_params = ", ".join(
                ["true"] * len(precomputed_template_parameters)
            )
            # This typedef (actually a using statement) is needed so that TORCH_META_FUNC can reuse the return
            # type (which has a variable number of template parameters).
            meta_return_typedef = f"using meta_return_ty = precompute_out <{meta_return_template_params}>;"
            meta_return = "meta_return_ty"
            precomputed_decl = f"""
                {precompute_template_decl}
                struct TORCH_API precompute_out {{
                    {setter_methods_decl}
                    {precomputed_elements_decl};
            }};"""
        else:
            # No precomputed elements: `meta` returns void and no extra struct is needed.
            meta_return_typedef = ""
            precomputed_decl = ""

        return f"""\
struct TORCH_API structured_{name} : public {parent_class} {{
    {precomputed_decl}
    {meta_return_typedef}
    {meta_return} meta({args_str});
}};
"""
  855. def needs_backend_select(f: NativeFunction, selector: SelectiveBuilder) -> bool:
  856. name = str(f.func.name.name)
  857. if name.endswith("_like") or name.startswith("new_"):
  858. return False
  859. if f.func.arguments.tensor_options is None:
  860. return False
  861. return selector.is_native_function_selected(f)
  862. # Generates RegisterBackendSelect.cpp, a series of kernels which provide
  863. # specialized computation of dispatch key for operator signatures which cannot
  864. # be easily done automatically using templating.
  865. @dataclass(frozen=True)
  866. class ComputeBackendSelect:
  867. target: Literal[Target.DEFINITION, Target.REGISTRATION]
  868. # Selector object to determine which operators to generate
  869. # registration code for.
  870. selector: SelectiveBuilder
  871. @method_with_native_function
  872. def __call__(self, f: NativeFunction) -> str | None:
  873. if not needs_backend_select(f, self.selector):
  874. return None
  875. name = native.name(f.func)
  876. # BackendSelect can go to Meta, so it must preserve symints
  877. native_sig = NativeSignature(f.func, symint=True)
  878. native_tensor_args = [
  879. a
  880. for a in native_sig.arguments()
  881. if isinstance(a.argument, Argument) and a.argument.type.is_tensor_like()
  882. ]
  883. dispatcher_sig = DispatcherSignature.from_schema(f.func)
  884. sig: NativeSignature | DispatcherSignature
  885. sig = dispatcher_sig
  886. dispatcher_exprs = dispatcher_sig.exprs()
  887. dispatch_key = "c10::computeDispatchKey(dtype, layout, device)"
  888. if self.target is Target.DEFINITION:
  889. # I don't think there's actually a good reason to generate
  890. # these two cases differently
  891. # The first case could probably be improved though- it calls computeDispatchKeySet(),
  892. # which looks at TLS dispatch keys- there should not be any by the time we reach backend select.
  893. if native_tensor_args:
  894. if not f.func.arguments.has_tensor_arg():
  895. raise AssertionError(
  896. f"Expected function to have tensor args: {f.func}"
  897. )
  898. tensor_args = ", ".join(a.name for a in native_tensor_args)
  899. compute_dk = f"""\
  900. DispatchKeySet _dk_set = c10::DispatchKeySet({dispatch_key}) | c10::detail::multi_dispatch_key_set({tensor_args});
  901. DispatchKeySet _dk_mask = c10::DispatchKeySet(DispatchKeySet::FULL_AFTER, DispatchKey::BackendSelect);
  902. DispatchKeySet _dk = c10::impl::computeDispatchKeySet(_dk_set, _dk_mask);"""
  903. else:
  904. if f.func.arguments.has_tensor_arg():
  905. raise AssertionError(
  906. f"Expected function to not have tensor args: {f.func}"
  907. )
  908. compute_dk = (
  909. f"DispatchKeySet _dk = c10::DispatchKeySet({dispatch_key});"
  910. )
  911. return f"""\
  912. // aten::{f.func}
  913. C10_ALWAYS_INLINE
  914. {sig.defn(name)} {{
  915. {compute_dk}
  916. return at::_ops::{f.func.name.unambiguous_name()}::redispatch(
  917. _dk, {", ".join(a.expr for a in dispatcher_exprs)});
  918. }}
  919. """
  920. elif self.target is Target.REGISTRATION:
  921. return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
  922. else:
  923. assert_never(self.target)
  924. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  925. #
  926. # YAML CODE GENERATION
  927. #
  928. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  929. def format_yaml(data: object) -> str:
  930. # Ignore alias in Dumper
  931. YamlDumper.ignore_aliases = lambda self, data: True # type: ignore[assignment]
  932. # Support serializing OrderedDict
  933. def dict_representer(dumper: Any, data: Any) -> Any:
  934. return dumper.represent_dict(data.items())
  935. YamlDumper.add_representer(OrderedDict, dict_representer) # type: ignore[no-untyped-call]
  936. # Some yaml parsers (e.g. Haskell's) don't understand line breaks.
  937. # width=1e9 turns off optional line breaks and improves
  938. # the portability of the outputted yaml.
  939. return yaml.dump(data, default_flow_style=False, Dumper=YamlDumper, width=1e9) # type: ignore[no-any-return, call-overload]
  940. # For some reason, some defaults we write to YAML are written as native
  941. # YAML objects, rather than doing them uniformly as strings. This
  942. # function detects those cases and converts them into native Python
  943. # objects.
  944. def pythonify_default(s: str) -> object:
  945. if s == "true":
  946. return True
  947. elif s == "false":
  948. return False
  949. try:
  950. return int(s)
  951. except ValueError:
  952. try:
  953. return float(s)
  954. except ValueError:
  955. return s
  956. # What is a dynamic type? Over time, the semantic meaning of
  957. # dynamic type has degraded to meaninglessness (in the old days,
  958. # it captured dtype-ness of types, but that has gone away with
  959. # the removal of TH). These days, it's mostly the same thing as
  960. # the C++ API argument type, except that Tensor and Tensor?
  961. # arguments simply present as Tensor.
  962. #
  963. # TODO: Get rid of dynamic_type, after getting tools/autograd
  964. # to use the new codegen framework
  965. def dynamic_type(t: Type) -> str:
  966. if isinstance(t, OptionalType):
  967. return dynamic_type(t.elem)
  968. # Note we don't use t.is_tensor_like() here because it would
  969. # also include Tensor[]
  970. if str(t) == "Tensor":
  971. return "at::Tensor"
  972. # This is a legacy concept, so never report SymInt
  973. return cpp.argumenttype_type(
  974. t, mutable=False, binds="__placeholder__", symint=False
  975. ).cpp_type()
  976. def compute_method_of_yaml(variants: set[Variant]) -> list[str]:
  977. # This is written out explicitly to ensure that Tensor and
  978. # namespace are put into the list in the right order
  979. method_of = ["Type"]
  980. if Variant.method in variants:
  981. method_of.append("Tensor")
  982. if Variant.function in variants:
  983. method_of.append("namespace")
  984. return method_of
  985. def compute_returns_yaml(
  986. f: NativeFunction,
  987. ) -> tuple[list[dict[str, str]], dict[str, str]]:
  988. # Note [name and field_name]
  989. # ~~~~~~~~~~~~~~~~~~~~~~~~~~
  990. # To understand name_to_field_name, we must first talk about this
  991. # schema:
  992. #
  993. # lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
  994. #
  995. # There is something very odd about this schema: it is an out
  996. # variant of the function (that is to say, it will convert into
  997. # at::lstsq_out() in the C++ API), but the names of the output
  998. # return arguments don't match the keyword argument names of
  999. # the inputs. It TURNS OUT that in this situation, the historical
  1000. # Declarations.yaml we want to output is this (abbreviated to
  1001. # only show relevant fields):
  1002. #
  1003. # arguments:
  1004. # ...
  1005. # - field_name: solution
  1006. # name: X
  1007. # - field_name: QR
  1008. # name: qr
  1009. # ...
  1010. #
  1011. # returns:
  1012. # - field_name: solution
  1013. # name: X
  1014. # - field_name: QR
  1015. # name: qr
  1016. #
  1017. # The name of the return fields is stored in 'field_name', and the
  1018. # name of the arguments is stored in 'name'. So when we process
  1019. # arguments, we need a way to get at the corresponding return. At
  1020. # the moment, this is most conveniently done by constructing a
  1021. # mapping from name (the argument concept) to field_name (the
  1022. # return concept) while processing return arguments, since we don't
  1023. # directly maintain this correspondence in the modeling of function
  1024. # schema itself.
  1025. #
  1026. # See also https://github.com/pytorch/pytorch/issues/43114
  1027. name_to_field_name: dict[str, str] = {}
  1028. # Compute the returns field of the YAML entry
  1029. names = cpp.return_names(f)
  1030. returns = []
  1031. for i, (r, name) in enumerate(zip(f.func.returns, names)):
  1032. ret = {
  1033. "dynamic_type": dynamic_type(r.type),
  1034. "name": name,
  1035. # legacy, report ints
  1036. "type": cpp.return_type(r, symint=False).cpp_type(),
  1037. }
  1038. if r.name:
  1039. # See Note [name and field_name]
  1040. ret["field_name"] = r.name
  1041. if f.func.is_out_fn():
  1042. name_to_field_name[f.func.arguments.out[i].name] = r.name
  1043. returns.append(ret)
  1044. return returns, name_to_field_name
  1045. # arguments in yaml roughly corresponds to the public C++ API
  1046. def compute_cpp_argument_yaml(
  1047. cpp_a: Binding,
  1048. *,
  1049. schema_order: bool,
  1050. kwarg_only_set: set[str],
  1051. out_arg_set: set[str],
  1052. name_to_field_name: dict[str, str],
  1053. ) -> object:
  1054. if isinstance(cpp_a.argument, TensorOptionsArguments):
  1055. arg: dict[str, object] = {
  1056. "annotation": None,
  1057. "dynamic_type": "at::TensorOptions",
  1058. "is_nullable": False,
  1059. "name": cpp_a.name,
  1060. "type": cpp_a.type,
  1061. "kwarg_only": True,
  1062. }
  1063. if cpp_a.default is not None:
  1064. arg["default"] = cpp_a.default
  1065. return arg
  1066. elif isinstance(cpp_a.argument, SelfArgument):
  1067. raise AssertionError
  1068. elif isinstance(cpp_a.argument, Argument):
  1069. return compute_argument_yaml(
  1070. cpp_a.argument,
  1071. schema_order=schema_order,
  1072. kwarg_only_set=kwarg_only_set,
  1073. out_arg_set=out_arg_set,
  1074. name_to_field_name=name_to_field_name,
  1075. )
  1076. def compute_argument_yaml(
  1077. a: Argument,
  1078. *,
  1079. schema_order: bool,
  1080. kwarg_only_set: set[str],
  1081. out_arg_set: set[str],
  1082. name_to_field_name: dict[str, str],
  1083. ) -> object:
  1084. arg: dict[str, object] = {
  1085. "annotation": str(a.annotation) if a.annotation else None,
  1086. "dynamic_type": dynamic_type(a.type),
  1087. "is_nullable": a.type.is_nullable(),
  1088. "name": a.name,
  1089. # legacy, report ints
  1090. "type": cpp.argument_type(a, binds="__placeholder__", symint=False).cpp_type(),
  1091. }
  1092. if a.default is not None:
  1093. arg["default"] = pythonify_default(
  1094. cpp.default_expr(a.default, a.type, symint=False)
  1095. )
  1096. if a.name in kwarg_only_set:
  1097. arg["kwarg_only"] = True
  1098. if a.name in out_arg_set:
  1099. arg["output"] = True
  1100. arg["allocate"] = True
  1101. # See Note [name and field_name]
  1102. if a.name in name_to_field_name:
  1103. arg["field_name"] = name_to_field_name[a.name]
  1104. # Historically, booleans don't get their size recorded, because it
  1105. # is already built into the cpp type (e.g., std::array<bool, 4>)
  1106. l = a.type.is_list_like()
  1107. if l is not None and l.size is not None and str(l.elem) != "bool":
  1108. arg["size"] = l.size
  1109. return arg
  1110. @with_native_function
  1111. def compute_declaration_yaml(f: NativeFunction) -> object:
  1112. returns, name_to_field_name = compute_returns_yaml(f)
  1113. # These sets are used to conveniently test if an argument is a
  1114. # kwarg-only or out argument
  1115. kwarg_only_set = {a.name for a in f.func.arguments.flat_kwarg_only}
  1116. out_arg_set = {a.name for a in f.func.arguments.out}
  1117. sig_group = CppSignatureGroup.from_native_function(
  1118. f, method=False, fallback_binding=False
  1119. )
  1120. cpp_args = sig_group.signature.arguments()
  1121. arguments = [
  1122. compute_cpp_argument_yaml(
  1123. cpp_a,
  1124. schema_order=False,
  1125. kwarg_only_set=kwarg_only_set,
  1126. out_arg_set=out_arg_set,
  1127. name_to_field_name=name_to_field_name,
  1128. )
  1129. for cpp_a in cpp_args
  1130. ]
  1131. schema_order_jit_arguments = list(f.func.schema_order_arguments())
  1132. schema_order_arguments = [
  1133. compute_argument_yaml(
  1134. a,
  1135. schema_order=True,
  1136. kwarg_only_set=kwarg_only_set,
  1137. out_arg_set=out_arg_set,
  1138. name_to_field_name=name_to_field_name,
  1139. )
  1140. for a in schema_order_jit_arguments
  1141. ]
  1142. cpp_schema_order_types = [
  1143. # NB: method here doesn't matter
  1144. r.type
  1145. for a in schema_order_jit_arguments
  1146. for r in cpp.argument(
  1147. a,
  1148. method=False,
  1149. cpp_no_default_args=set(),
  1150. faithful=False,
  1151. symint=False,
  1152. has_tensor_options=False,
  1153. )
  1154. ]
  1155. # legacy, report ints
  1156. cpp_returns = cpp.returns_type(f.func.returns, symint=False).cpp_type()
  1157. schema_order_cpp_signature = f"{cpp_returns} ({', '.join(cpp_schema_order_types)})"
  1158. is_factory_method = (
  1159. any(isinstance(a.argument, TensorOptionsArguments) for a in cpp_args)
  1160. and Variant.method not in f.variants
  1161. )
  1162. return OrderedDict(
  1163. [
  1164. ("name", cpp.name(f.func)),
  1165. ("operator_name", str(f.func.name.name)),
  1166. ("overload_name", str(f.func.name.overload_name)),
  1167. ("manual_kernel_registration", f.manual_kernel_registration),
  1168. (
  1169. "category_override",
  1170. f.category_override if f.category_override is not None else "",
  1171. ),
  1172. ("schema_string", f"aten::{f.func}"),
  1173. ("arguments", arguments),
  1174. ("schema_order_cpp_signature", schema_order_cpp_signature),
  1175. ("schema_order_arguments", schema_order_arguments),
  1176. ("method_of", compute_method_of_yaml(f.variants)),
  1177. ("mode", "native"),
  1178. ("python_module", "" if f.python_module is None else f.python_module),
  1179. ("returns", returns),
  1180. ("inplace", f.func.name.name.inplace),
  1181. ("is_factory_method", is_factory_method),
  1182. ("abstract", f.is_abstract),
  1183. ("device_guard", f.device_guard),
  1184. ("with_gil", False),
  1185. ("deprecated", False),
  1186. ("has_math_kernel", f.has_composite_implicit_autograd_kernel),
  1187. ]
  1188. )
  1189. # See Note [Auto generated composite kernels]
  1190. def has_autogenerated_composite_kernel(f: NativeFunction) -> bool:
  1191. return (f.structured or f.structured_delegate is not None) and (
  1192. f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace
  1193. )
  1194. @with_native_function_and_indices
  1195. def compute_registration_declarations(
  1196. f: NativeFunction, backend_indices: dict[DispatchKey, BackendIndex]
  1197. ) -> str:
  1198. name = dispatcher.name(f.func)
  1199. returns_type = dispatcher.returns_type(f.func.returns).cpp_type()
  1200. args = dispatcher.arguments(f.func)
  1201. args_str = ", ".join(a.no_default().decl() for a in args)
  1202. comment_data: dict[str, str] = {
  1203. "schema": f"aten::{f.func}",
  1204. # TODO: What exactly is the semantics of the 'dispatch' field?
  1205. "dispatch": str(
  1206. {k for k, v in backend_indices.items() if v.has_kernel(f)}
  1207. != {DispatchKey.CompositeImplicitAutograd}
  1208. and {k for k, v in backend_indices.items() if v.has_kernel(f)}
  1209. != {
  1210. DispatchKey.CompositeImplicitAutograd,
  1211. DispatchKey.CompositeImplicitAutogradNestedTensor,
  1212. }
  1213. ),
  1214. "default": str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)),
  1215. }
  1216. return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)}
  1217. """
  1218. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  1219. #
  1220. # RUN IT ALL
  1221. #
  1222. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
  1223. def get_custom_build_selector(
  1224. provided_op_registration_allowlist: list[str] | None,
  1225. op_selection_yaml_path: str | None,
  1226. ) -> SelectiveBuilder:
  1227. if (
  1228. provided_op_registration_allowlist is not None
  1229. and op_selection_yaml_path is not None
  1230. ):
  1231. raise AssertionError(
  1232. "Both provided_op_registration_allowlist and op_selection_yaml_path "
  1233. "can NOT be provided at the same time."
  1234. )
  1235. op_registration_allowlist: set[str] | None = None
  1236. if provided_op_registration_allowlist is not None:
  1237. op_registration_allowlist = set(provided_op_registration_allowlist)
  1238. if op_registration_allowlist is not None:
  1239. selector = SelectiveBuilder.from_legacy_op_registration_allow_list(
  1240. op_registration_allowlist,
  1241. True,
  1242. False,
  1243. )
  1244. elif op_selection_yaml_path is not None:
  1245. selector = SelectiveBuilder.from_yaml_path(op_selection_yaml_path)
  1246. else:
  1247. selector = SelectiveBuilder.get_nop_selector()
  1248. return selector
  1249. def get_grouped_by_view_native_functions(
  1250. native_functions: Sequence[NativeFunction],
  1251. ) -> Sequence[NativeFunction | NativeFunctionsViewGroup]:
  1252. def maybe_create_view_group(
  1253. d: dict[ViewSchemaKind | SchemaKind, NativeFunction],
  1254. ) -> list[NativeFunction | NativeFunctionsViewGroup]:
  1255. funcs: list[NativeFunction | NativeFunctionsViewGroup] = []
  1256. if ViewSchemaKind.aliasing in d:
  1257. view = d.pop(ViewSchemaKind.aliasing)
  1258. view_inplace = d.pop(ViewSchemaKind.aliasing_inplace, None)
  1259. view_copy = d.pop(SchemaKind.functional, None)
  1260. funcs.append(
  1261. NativeFunctionsViewGroup(
  1262. view=view,
  1263. view_copy=view_copy,
  1264. view_inplace=view_inplace,
  1265. )
  1266. )
  1267. # Take the remaining functions that weren't part of the view group
  1268. # and emit them separately
  1269. funcs.extend(d.values())
  1270. return funcs
  1271. grouped_by_views: dict[
  1272. FunctionSchema, dict[SchemaKind | ViewSchemaKind, NativeFunction]
  1273. ] = defaultdict(dict)
  1274. for f in native_functions:
  1275. schema = f.func.view_signature()
  1276. view_kind: ViewSchemaKind = f.view_schema_kind
  1277. # We need to group up ops relevant to the same "view", consisting of:
  1278. # view op (ViewSchemaKind.aliasing)
  1279. # view_inplace op (ViewSchemaKind.aliasing_inplace)
  1280. # view_copy op (SchemaKind.functional)
  1281. if view_kind == ViewSchemaKind.non_aliasing:
  1282. kind = f.func.kind()
  1283. if kind in grouped_by_views[schema]:
  1284. raise AssertionError(
  1285. f"Duplicate schema kind {kind} in {grouped_by_views[schema].keys()}"
  1286. )
  1287. grouped_by_views[schema][kind] = f
  1288. else:
  1289. if view_kind in grouped_by_views[schema]:
  1290. raise AssertionError(
  1291. f"{view_kind} already in {grouped_by_views[schema].keys()}"
  1292. )
  1293. grouped_by_views[schema][view_kind] = f
  1294. return list(concatMap(maybe_create_view_group, grouped_by_views.values()))
  1295. def get_grouped_native_functions(
  1296. native_functions: Sequence[NativeFunction],
  1297. ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
  1298. def flatten_pre_group(
  1299. d: dict[SchemaKind, NativeFunction],
  1300. ) -> Sequence[NativeFunction | NativeFunctionsGroup]:
  1301. r = NativeFunctionsGroup.from_dict(d)
  1302. if r is None:
  1303. # Invariant: any NativeFunctions that are code-generated
  1304. # should have been grouped into NativeFunctionsGroup objects
  1305. if any("generated" in f.tags for f in d.values()):
  1306. raise AssertionError(
  1307. "Generated NativeFunctions should have been grouped into "
  1308. f"NativeFunctionsGroup objects: {list(d.values())}"
  1309. )
  1310. return list(d.values())
  1311. else:
  1312. return [r]
  1313. # TODO: how come ValuesView isn't a Sequence lol
  1314. pre_grouped_native_functions = pre_group_native_functions(native_functions)
  1315. return list(
  1316. concatMap(flatten_pre_group, list(pre_grouped_native_functions.values()))
  1317. )
  1318. def get_ns_grouped_kernels(
  1319. *,
  1320. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  1321. backend_indices: dict[DispatchKey, BackendIndex],
  1322. native_function_decl_gen: Callable[
  1323. [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
  1324. ] = dest.compute_native_function_declaration,
  1325. ) -> dict[str, list[str]]:
  1326. ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
  1327. for f in grouped_native_functions:
  1328. native_function_namespaces = set()
  1329. dispatch_keys = set()
  1330. for dispatch_key, backend_idx in backend_indices.items():
  1331. backend_metadata = backend_idx.get_kernel(f)
  1332. if backend_metadata:
  1333. namespace = backend_metadata.cpp_namespace
  1334. dispatch_keys.add(dispatch_key)
  1335. native_function_namespaces.add(namespace)
  1336. else:
  1337. namespace = DEFAULT_KERNEL_NAMESPACE
  1338. if len(native_function_namespaces) > 1:
  1339. raise AssertionError(
  1340. f"Codegen only supports one namespace per operator, "
  1341. f"got {native_function_namespaces} from {dispatch_keys}"
  1342. )
  1343. ns_grouped_kernels[namespace].extend(
  1344. native_function_decl_gen(f, backend_idx)
  1345. )
  1346. return ns_grouped_kernels
  1347. def get_native_function_declarations_from_ns_grouped_kernels(
  1348. *,
  1349. ns_grouped_kernels: dict[str, list[str]],
  1350. ) -> list[str]:
  1351. declarations: list[str] = []
  1352. newline = "\n"
  1353. for namespace, kernels in ns_grouped_kernels.items():
  1354. ns_helper = NamespaceHelper(
  1355. namespace_str=namespace,
  1356. entity_name="",
  1357. max_level=4,
  1358. )
  1359. # Convert to a set first to remove duplicate kernel names. Backends are
  1360. # allowed to repeat kernel names; only generate the declaration once!
  1361. ordered_kernels = list(OrderedDict.fromkeys(kernels))
  1362. declarations.extend(
  1363. f"""
  1364. {ns_helper.prologue}
  1365. {newline.join(ordered_kernels)}
  1366. {ns_helper.epilogue}
  1367. """.split(newline)
  1368. )
  1369. return declarations
  1370. # Return native function declarations grouped by their namespaces.
  1371. def get_native_function_declarations(
  1372. *,
  1373. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  1374. backend_indices: dict[DispatchKey, BackendIndex],
  1375. native_function_decl_gen: Callable[
  1376. [NativeFunctionsGroup | NativeFunction, BackendIndex], list[str]
  1377. ] = dest.compute_native_function_declaration,
  1378. ) -> list[str]:
  1379. """
  1380. Generate kernel declarations, in `NativeFunction(s).h`.
  1381. :param grouped_native_functions: a sequence of `NativeFunction` or `NativeFunctionGroup`.
  1382. :param backend_indices: kernel collections grouped by dispatch key.
  1383. :param native_function_decl_gen: callable to generate kernel declaration for each `NativeFunction`.
  1384. :return: a list of string, from the string with all declarations, grouped by namespaces, split by newline.
  1385. """
  1386. ns_grouped_kernels = get_ns_grouped_kernels(
  1387. grouped_native_functions=grouped_native_functions,
  1388. backend_indices=backend_indices,
  1389. native_function_decl_gen=native_function_decl_gen,
  1390. )
  1391. return get_native_function_declarations_from_ns_grouped_kernels(
  1392. ns_grouped_kernels=ns_grouped_kernels
  1393. )
  1394. def get_kernel_namespace(
  1395. *, f: NativeFunction | NativeFunctionsGroup, backend_idx: BackendIndex
  1396. ) -> str:
  1397. backend_metadata = backend_idx.get_kernel(f)
  1398. if backend_metadata and "::native" not in backend_metadata.cpp_namespace:
  1399. func_name = (
  1400. f.func.name if isinstance(f, NativeFunction) else f.functional.func.name
  1401. )
  1402. raise AssertionError(
  1403. f"The kernel for function {func_name} "
  1404. f"with dispatch key {backend_idx.dispatch_key} "
  1405. f"has a namespace {backend_metadata.cpp_namespace} and it's not ending with '::native'."
  1406. )
  1407. return (
  1408. backend_metadata.cpp_namespace if backend_metadata else DEFAULT_KERNEL_NAMESPACE
  1409. )
  1410. # Return native function definitions grouped by dispatch key and custom namespace.
  1411. # Used in RegisterDispatchKey.cpp and etc.
  1412. def get_native_function_definitions(
  1413. *,
  1414. fm: FileManager,
  1415. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  1416. dispatch_key: DispatchKey,
  1417. backend_idx: BackendIndex,
  1418. selector: SelectiveBuilder,
  1419. rocm: bool,
  1420. symint: bool,
  1421. skip_dispatcher_op_registration: bool,
  1422. gen_dispatch_helpers: bool,
  1423. ) -> list[str]:
  1424. definitions: list[str] = []
  1425. ns_definitions: dict[str, list[str]] = defaultdict(list)
  1426. anonymous_definitions: dict[str, list[str]] = defaultdict(list)
  1427. registrations: dict[str, dict[str, list[str]]] = defaultdict(dict)
  1428. newline = "\n"
  1429. ns_gen = dest.RegisterDispatchKey(
  1430. backend_idx,
  1431. Target.NAMESPACED_DEFINITION,
  1432. selector,
  1433. rocm=rocm,
  1434. symint=symint,
  1435. class_method_name=None,
  1436. skip_dispatcher_op_registration=skip_dispatcher_op_registration,
  1437. )
  1438. anonymous_gen = dest.RegisterDispatchKey(
  1439. backend_idx,
  1440. Target.ANONYMOUS_DEFINITION,
  1441. selector,
  1442. rocm=rocm,
  1443. symint=symint,
  1444. class_method_name=None,
  1445. skip_dispatcher_op_registration=skip_dispatcher_op_registration,
  1446. )
  1447. reg_gen = dest.RegisterDispatchKey(
  1448. backend_idx,
  1449. Target.REGISTRATION,
  1450. selector,
  1451. rocm=rocm,
  1452. symint=symint,
  1453. class_method_name=None,
  1454. skip_dispatcher_op_registration=skip_dispatcher_op_registration,
  1455. )
  1456. for f in grouped_native_functions:
  1457. kernel_namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
  1458. "::native", ""
  1459. )
  1460. ns_definitions[kernel_namespace].extend(
  1461. ns_gen(f),
  1462. )
  1463. anonymous_definitions[kernel_namespace].extend(
  1464. anonymous_gen(f),
  1465. )
  1466. namespace = (
  1467. f.namespace if isinstance(f, NativeFunction) else f.functional.namespace
  1468. )
  1469. if namespace not in registrations[kernel_namespace]:
  1470. registrations[kernel_namespace] = defaultdict(list)
  1471. registrations[kernel_namespace][namespace].extend(
  1472. reg_gen(f),
  1473. )
  1474. for kernel_namespace in ns_definitions:
  1475. if len(ns_definitions[kernel_namespace]) == 0:
  1476. continue
  1477. ns_helper = NamespaceHelper(namespace_str=kernel_namespace)
  1478. registration_body = ""
  1479. for namespace in registrations[kernel_namespace]:
  1480. if not registrations[kernel_namespace][namespace]:
  1481. continue
  1482. registration_body += f"""
  1483. TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
  1484. {newline.join(registrations[kernel_namespace][namespace])}
  1485. }}"""
  1486. definitions.extend(
  1487. fm.substitute_with_template(
  1488. "RegisterDispatchDefinitions.ini",
  1489. lambda: {
  1490. "ns_prologue": ns_helper.prologue,
  1491. "ns_epilogue": ns_helper.epilogue,
  1492. "dispatch_anonymous_definitions": anonymous_definitions[
  1493. kernel_namespace
  1494. ],
  1495. "static_init_dispatch_registrations": ""
  1496. if skip_dispatcher_op_registration
  1497. else registration_body,
  1498. "deferred_dispatch_registrations": "",
  1499. "dispatch_namespace": dispatch_key.lower(),
  1500. "dispatch_namespaced_definitions": ns_definitions[kernel_namespace],
  1501. },
  1502. ).split(newline)
  1503. )
  1504. return definitions
  1505. # Return native function declarations grouped by dispatch key and custom namespace.
  1506. # Used in CPUFunctions_inl.h and etc.
  1507. def get_namespaced_declaration(
  1508. *,
  1509. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  1510. dispatch_key: DispatchKey,
  1511. backend_idx: BackendIndex,
  1512. selector: SelectiveBuilder,
  1513. rocm: bool,
  1514. symint: bool,
  1515. ) -> list[str]:
  1516. declarations: list[str] = []
  1517. ns_grouped_kernels: dict[str, list[str]] = defaultdict(list)
  1518. newline = "\n"
  1519. func = dest.RegisterDispatchKey(
  1520. backend_idx,
  1521. Target.NAMESPACED_DECLARATION,
  1522. selector,
  1523. rocm=rocm,
  1524. class_method_name=None,
  1525. skip_dispatcher_op_registration=False,
  1526. symint=symint,
  1527. )
  1528. for f in grouped_native_functions:
  1529. namespace = get_kernel_namespace(f=f, backend_idx=backend_idx).replace(
  1530. "native", dispatch_key.lower()
  1531. )
  1532. ns_grouped_kernels[namespace].extend(
  1533. func(f),
  1534. )
  1535. for namespace, kernels in ns_grouped_kernels.items():
  1536. if len(kernels) == 0:
  1537. continue
  1538. ns_helper = NamespaceHelper(
  1539. namespace_str=namespace, entity_name="", max_level=3
  1540. )
  1541. ordered_kernels = list(OrderedDict.fromkeys(kernels))
  1542. declarations.extend(
  1543. f"""
  1544. {ns_helper.prologue}
  1545. {newline.join(ordered_kernels)}
  1546. {ns_helper.epilogue}
  1547. """.split(newline)
  1548. )
  1549. return declarations
  1550. # Return native function schema registration code for aten and other namespaces.
  1551. def get_native_function_schema_registrations(
  1552. *,
  1553. native_functions: Sequence[NativeFunction],
  1554. schema_selector: SelectiveBuilder,
  1555. ) -> tuple[list[str], str]:
  1556. ns_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
  1557. for native_function in native_functions:
  1558. ns_native_functions[native_function.namespace].append(native_function)
  1559. schema_registrations = ""
  1560. aten_schema_registrations = []
  1561. custom_namespace = None
  1562. for namespace, funcs in ns_native_functions.items():
  1563. schema_registrations_body = list(
  1564. mapMaybe(RegisterSchema(schema_selector), funcs)
  1565. )
  1566. # NB: we have to separate aten namespace registration from other namespaces,
  1567. # because in the template we hardcoded an operator for ATen already.
  1568. if namespace == "aten":
  1569. aten_schema_registrations = schema_registrations_body
  1570. else:
  1571. custom_namespace = namespace
  1572. tab = "\t"
  1573. # if the namespace is predefined, we should use define a library fragment
  1574. # instead of a new library
  1575. torch_library_macro = (
  1576. "TORCH_LIBRARY_FRAGMENT"
  1577. if namespace in FRAGMENT_NAMESPACES
  1578. else "TORCH_LIBRARY"
  1579. )
  1580. schema_registrations += f"""
  1581. {torch_library_macro}({custom_namespace}, m) {{
  1582. {tab.join(schema_registrations_body)}
  1583. }};"""
  1584. return (aten_schema_registrations, schema_registrations)
  1585. def gen_aggregated_headers(
  1586. *,
  1587. native_functions: Sequence[NativeFunction],
  1588. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  1589. structured_native_functions: Sequence[NativeFunctionsGroup],
  1590. static_dispatch_idx: list[BackendIndex],
  1591. selector: SelectiveBuilder,
  1592. backend_indices: dict[DispatchKey, BackendIndex],
  1593. cpu_fm: FileManager,
  1594. device_fms: dict[str, FileManager],
  1595. functions_keys: set[DispatchKey],
  1596. dispatch_keys: Sequence[DispatchKey],
  1597. rocm: bool,
  1598. ) -> None:
  1599. # Buck doesn't support dynamic output files, so we aggregate all operator
  1600. # headers into a single file
  1601. cpu_fm.write(
  1602. "NativeMetaFunctions.h",
  1603. lambda: {
  1604. "NativeMetaFunctions_includes": [],
  1605. "NativeMetaFunctions_declarations": list(
  1606. mapMaybe(compute_meta_function_declaration, structured_native_functions)
  1607. ),
  1608. },
  1609. )
  1610. method_native_functions = [
  1611. fn for fn in native_functions if Variant.method in fn.variants
  1612. ]
  1613. non_method_native_functions = [
  1614. fn for fn in native_functions if fn not in method_native_functions
  1615. ]
  1616. cpu_fm.write(
  1617. "MethodOperators.h",
  1618. lambda: {
  1619. "MethodOperators_includes": [],
  1620. "MethodOperators_declarations": list(
  1621. mapMaybe(
  1622. ComputeOperators(
  1623. Target.DECLARATION,
  1624. static_dispatch_backend_indices=static_dispatch_idx,
  1625. ),
  1626. method_native_functions,
  1627. )
  1628. ),
  1629. },
  1630. )
  1631. cpu_fm.write(
  1632. "Operators.h",
  1633. lambda: {
  1634. "Operators_includes": ["#include <ATen/MethodOperators.h>"],
  1635. "Operators_declarations": list(
  1636. mapMaybe(
  1637. ComputeOperators(
  1638. Target.DECLARATION,
  1639. static_dispatch_backend_indices=static_dispatch_idx,
  1640. ),
  1641. non_method_native_functions,
  1642. )
  1643. ),
  1644. },
  1645. )
  1646. cpu_fm.write(
  1647. "Functions.h",
  1648. lambda: {
  1649. "static_dispatch_extra_headers": static_dispatch_extra_headers(
  1650. static_dispatch_idx
  1651. ),
  1652. "Functions_includes": ["#include <ATen/Operators.h>"],
  1653. "Functions_declarations": list(
  1654. mapMaybe(
  1655. ComputeFunction(),
  1656. native_functions,
  1657. )
  1658. ),
  1659. },
  1660. )
  1661. declarations = get_native_function_declarations(
  1662. grouped_native_functions=grouped_native_functions,
  1663. backend_indices=backend_indices,
  1664. )
  1665. cpu_fm.write(
  1666. "NativeFunctions.h",
  1667. lambda: {
  1668. "NativeFunctions_includes": ["#include <ATen/NativeMetaFunctions.h>"],
  1669. "NativeFunctions_declarations": declarations,
  1670. },
  1671. )
  1672. for dispatch_key in dispatch_keys:
  1673. fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
  1674. if dispatch_key in functions_keys:
  1675. inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"
  1676. fm.write_with_template(
  1677. f"{dispatch_key}Functions.h",
  1678. "DispatchKeyFunctions.h",
  1679. lambda: {
  1680. "dispatch_key": str(dispatch_key),
  1681. "inline_headers": inl_headers,
  1682. },
  1683. )
  1684. fm.write_with_template(
  1685. f"{dispatch_key}Functions_inl.h",
  1686. "DispatchKeyFunctions_inl.h",
  1687. lambda: {
  1688. "DispatchKeyFunctions_inl_includes": [],
  1689. "dispatch_namespace": dispatch_key.lower(),
  1690. "dispatch_namespaced_declarations": get_namespaced_declaration(
  1691. grouped_native_functions=grouped_native_functions,
  1692. dispatch_key=dispatch_key,
  1693. backend_idx=backend_indices[dispatch_key],
  1694. selector=selector,
  1695. rocm=rocm,
  1696. symint=True,
  1697. ),
  1698. },
  1699. )
  1700. del fm
def gen_per_operator_headers(
    *,
    native_functions: Sequence[NativeFunction],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    ops_fm: FileManager,
    functions_keys: set[DispatchKey],
    dispatch_keys: Sequence[DispatchKey],
    rocm: bool,
) -> None:
    """Write one set of headers per operator root name, plus umbrella headers.

    Per root name, ``ops_fm`` receives ``<name>_ops.h``, ``<name>.h``,
    ``<name>_native.h`` and, when the root name has a structured group,
    ``<name>_meta.h``.  For every dispatch key in ``functions_keys`` a
    ``<name>_<key>_dispatch.h`` header is written for each root name that
    yields at least one declaration.  The category headers (``Functions.h``,
    ``Operators.h``, ``NativeMetaFunctions.h``, ``NativeFunctions.h``,
    ``MethodOperators.h``) and the ``<Key>Functions*.h`` headers then only
    contain ``#include`` lines for the per-operator files.
    """
    # For CMake builds, split operator declarations into separate headers in
    # the ATen/ops folder to split up header dependencies
    functions_by_root_name: dict[str, list[NativeFunction]] = defaultdict(list)
    for fn in native_functions:
        functions_by_root_name[fn.root_name].append(fn)

    # Same bucketing, but for the grouped view of the functions.
    grouped_functions_by_root_name: dict[
        str, list[NativeFunction | NativeFunctionsGroup]
    ] = defaultdict(list)
    for group in grouped_native_functions:
        name = group.root_name
        grouped_functions_by_root_name[name].append(group)

    for name, functions in functions_by_root_name.items():
        # <name>_ops.h: operator declarations.
        ops_fm.write_with_template(
            f"{name}_ops.h",
            "Operator.h",
            lambda: {
                "declarations": list(
                    mapMaybe(
                        ComputeOperators(
                            Target.DECLARATION,
                            static_dispatch_backend_indices=static_dispatch_idx,
                        ),
                        functions,
                    )
                ),
            },
        )

        # <name>.h: function definitions, including the matching _ops.h.
        ops_fm.write_with_template(
            f"{name}.h",
            "Function.h",
            lambda: {
                "static_dispatch_ops_headers": list(
                    mapMaybe(
                        lambda fn: static_dispatch_ops_header(
                            fn, backend_index=static_dispatch_idx
                        ),
                        functions,
                    )
                ),
                "operator_includes": f"#include <ATen/ops/{name}_ops.h>",
                "function_definitions": list(
                    mapMaybe(
                        ComputeFunction(),
                        functions,
                    )
                ),
            },
        )

        grouped_functions = grouped_functions_by_root_name.get(name, [])
        structured_functions = [
            fn
            for fn in grouped_functions
            if isinstance(fn, NativeFunctionsGroup) and fn.structured
        ]
        is_structured = len(structured_functions) > 0

        # <name>_meta.h is only written when there is a structured group.
        if is_structured:
            ops_fm.write_with_template(
                f"{name}_meta.h",
                "NativeMetaFunction.h",
                lambda: {
                    "meta_function_declarations": list(
                        mapMaybe(
                            compute_meta_function_declaration, structured_functions
                        )
                    ),
                },
            )
        declarations = get_native_function_declarations(
            grouped_native_functions=grouped_functions,
            backend_indices=backend_indices,
            native_function_decl_gen=dest.compute_native_function_declaration,
        )
        # <name>_native.h includes the meta header only if one was written.
        ops_fm.write_with_template(
            f"{name}_native.h",
            "NativeFunction.h",
            lambda: {
                "extra_includes": (
                    f"#include <ATen/ops/{name}_meta.h>" if is_structured else []
                ),
                "native_function_declarations": declarations,
            },
        )

    # The category headers carry no declarations of their own here; they just
    # include the per-operator header with the matching suffix for every
    # root name.
    for category, suffix in [
        ("Functions", ""),
        ("Operators", "_ops"),
        ("NativeMetaFunctions", "_meta"),
        ("NativeFunctions", "_native"),
    ]:
        cpu_fm.write(
            f"{category}.h",
            lambda: {
                f"{category}_includes": [
                    f"#include <ATen/ops/{name}{suffix}.h>"
                    for name in sorted(functions_by_root_name.keys())
                ],
                f"{category}_declarations": [],
            },
        )

    for dispatch_key in dispatch_keys:
        if dispatch_key not in functions_keys:
            continue

        dispatch_namespace = dispatch_key.lower()
        # Root names that produced a dispatch header for this key.
        dispatch_names = []

        for name, functions in functions_by_root_name.items():
            grouped_functions = grouped_functions_by_root_name.get(name, [])
            declarations = list(
                concatMap(
                    dest.RegisterDispatchKey(
                        backend_indices[dispatch_key],
                        Target.NAMESPACED_DECLARATION,
                        selector,
                        rocm=rocm,
                        symint=True,
                        class_method_name=None,
                        skip_dispatcher_op_registration=False,
                    ),
                    grouped_functions,
                )
            )

            # Nothing to declare for this root name under this dispatch key.
            if len(declarations) == 0:
                continue

            dispatch_names.append(name)
            ops_fm.write_with_template(
                f"{name}_{dispatch_namespace}_dispatch.h",
                "DispatchKeyFunction.h",
                lambda: {
                    "dispatch_namespace": dispatch_namespace,
                    "dispatch_namespaced_declarations": declarations,
                },
            )

        fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
        inl_headers = f"#include <ATen/{dispatch_key}Functions_inl.h>"

        fm.write_with_template(
            f"{dispatch_key}Functions.h",
            "DispatchKeyFunctions.h",
            lambda: {
                "dispatch_key": str(dispatch_key),
                "inline_headers": inl_headers,
            },
        )
        # The _inl.h header only includes the dispatch headers written above.
        fm.write_with_template(
            f"{dispatch_key}Functions_inl.h",
            "DispatchKeyFunctions_inl.h",
            lambda: {
                "dispatch_namespace": dispatch_namespace,
                "DispatchKeyFunctions_inl_includes": [
                    f"#include <ATen/ops/{name}_{dispatch_namespace}_dispatch.h>"
                    for name in sorted(dispatch_names)
                ],
                "dispatch_namespaced_declarations": [],
            },
        )
        del fm

    # MethodOperators.h includes the _ops.h header of every root name that
    # has at least one function with a method variant.
    cpu_fm.write(
        "MethodOperators.h",
        lambda: {
            "MethodOperators_includes": sorted(
                f"#include <ATen/ops/{name}_ops.h>"
                for name, functions in functions_by_root_name.items()
                if any(Variant.method in fn.variants for fn in functions)
            ),
            "MethodOperators_declarations": [],
        },
    )
def gen_headers(
    *,
    native_functions: Sequence[NativeFunction],
    valid_tags: set[str],
    grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
    structured_native_functions: Sequence[NativeFunctionsGroup],
    static_dispatch_idx: list[BackendIndex],
    selector: SelectiveBuilder,
    backend_indices: dict[DispatchKey, BackendIndex],
    core_fm: FileManager,
    cpu_fm: FileManager,
    device_fms: dict[str, FileManager],
    ops_fm: FileManager,
    dispatch_keys: Sequence[DispatchKey],
    functions_keys: set[DispatchKey],
    rocm: bool,
    per_operator_headers: bool,
) -> None:
    """Write all generated headers.

    The operator headers are produced either per operator
    (``gen_per_operator_headers``) or aggregated (``gen_aggregated_headers``),
    depending on ``per_operator_headers``.  Afterwards the following headers
    are written regardless of that choice: ``TensorBody.h``,
    ``aten_interned_strings.h`` and ``enum_tag.h`` through ``core_fm``;
    ``RedispatchFunctions.h``, ``RegistrationDeclarations.h`` and
    ``VmapGeneratedPlumbing.h`` through ``cpu_fm``.
    """
    if per_operator_headers:
        gen_per_operator_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            ops_fm=ops_fm,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )
    else:
        # The aggregated path takes the structured groups and does not use ops_fm.
        gen_aggregated_headers(
            native_functions=native_functions,
            grouped_native_functions=grouped_native_functions,
            structured_native_functions=structured_native_functions,
            static_dispatch_idx=static_dispatch_idx,
            selector=selector,
            backend_indices=backend_indices,
            cpu_fm=cpu_fm,
            device_fms=device_fms,
            dispatch_keys=dispatch_keys,
            functions_keys=functions_keys,
            rocm=rocm,
        )

    # Tensor method declarations and definitions share one header.
    core_fm.write(
        "TensorBody.h",
        lambda: {
            "tensor_method_declarations": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DECLARATION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
            "tensor_method_definitions": list(
                mapMaybe(
                    ComputeTensorMethod(
                        target=Target.DEFINITION,
                        static_dispatch_backend_indices=static_dispatch_idx,
                    ),
                    native_functions,
                )
            ),
        },
    )

    cpu_fm.write(
        "RedispatchFunctions.h",
        lambda: {
            "function_redispatch_definitions": list(
                mapMaybe(ComputeRedispatchFunction(), native_functions)
            ),
        },
    )

    # One declaration line (with JSON metadata comment) per native function.
    cpu_fm.write(
        "RegistrationDeclarations.h",
        lambda: {
            "registration_declarations": [
                compute_registration_declarations(f, backend_indices)
                for f in native_functions
            ],
        },
    )

    cpu_fm.write(
        "VmapGeneratedPlumbing.h", lambda: gen_all_vmap_plumbing(native_functions)
    )

    def gen_aten_interned_strings() -> dict[str, str]:
        # Builds the template environment for aten_interned_strings.h: one
        # `_(aten, <name>)` entry per operator name and one `_(attr, <name>)`
        # entry per argument name, sorted and joined with line continuations.
        attrs: set[str] = set()  # All function argument names
        names = set()  # All ATen function names
        for func in native_functions:
            names.add(str(func.func.name.name))
            # Some operators don't have a functional variant but we still create a
            # symbol without the underscore
            names.add(func.func.name.name.base)

            attrs.update(arg.name for arg in func.func.schema_order_arguments())

        # These are keywords in C++, so aren't valid symbol names
        # https://en.cppreference.com/w/cpp/language/operator_alternative
        names -= {
            "and",
            "and_eq",
            "bitand",
            "bitor",
            "compl",
            "not",
            "not_eq",
            "or",
            "or_eq",
            "xor",
            "xor_eq",
        }

        return {
            "aten_symbols": " \\\n".join(
                [f"_(aten, {name})" for name in sorted(names)]
            ),
            "attr_symbols": " \\\n".join(
                [f"_(attr, {name})" for name in sorted(attrs)]
            ),
        }

    core_fm.write("aten_interned_strings.h", gen_aten_interned_strings)

    def gen_tags_enum() -> dict[str, str]:
        # Template environment for enum_tag.h: the valid tags, sorted, one per line.
        return {"enum_of_valid_tags": (",\n".join(sorted(valid_tags)))}

    core_fm.write("enum_tag.h", gen_tags_enum)
  2004. def gen_source_files(
  2005. *,
  2006. native_functions: Sequence[NativeFunction],
  2007. grouped_native_functions: Sequence[NativeFunction | NativeFunctionsGroup],
  2008. structured_native_functions: Sequence[NativeFunctionsGroup],
  2009. view_groups: Sequence[NativeFunctionsViewGroup],
  2010. selector: SelectiveBuilder,
  2011. static_dispatch_idx: list[BackendIndex],
  2012. backend_indices: dict[DispatchKey, BackendIndex],
  2013. aoti_fm: FileManager,
  2014. core_fm: FileManager,
  2015. cpu_vec_fm: FileManager,
  2016. cpu_fm: FileManager,
  2017. device_fms: dict[str, FileManager],
  2018. dispatch_keys: Sequence[DispatchKey],
  2019. functions_keys: set[DispatchKey],
  2020. rocm: bool,
  2021. force_schema_registration: bool,
  2022. per_operator_headers: bool,
  2023. skip_dispatcher_op_registration: bool,
  2024. update_aoti_c_shim: bool,
  2025. aoti_backends: set[DispatchKey | None],
  2026. extend_aoti_c_shim: bool,
  2027. ) -> None:
  2028. extra_cuda_headers = """\
  2029. #include <c10/cuda/CUDAGuard.h>
  2030. #include <ATen/cuda/ATenCUDAGeneral.h>
  2031. #include <ATen/cuda/CUDADevice.h>
  2032. #include <ATen/cuda/CUDAContext.h>"""
  2033. if rocm:
  2034. extra_cuda_headers = """\
  2035. #include <c10/hip/HIPGuard.h>
  2036. #include <ATen/hip/ATenHIPGeneral.h>
  2037. #include <ATen/hip/HIPDevice.h>
  2038. #include <ATen/hip/HIPContext.h>"""
  2039. for dispatch_key in dispatch_keys:
  2040. fm = file_manager_from_dispatch_key(dispatch_key, device_fms, cpu_fm)
  2041. if per_operator_headers:
  2042. def operator_headers() -> list[str]:
  2043. headers = []
  2044. for g in grouped_native_functions:
  2045. is_registered = False
  2046. if backend_index.has_kernel(g):
  2047. is_registered = True
  2048. # The above has_kernel test on a group will only test for
  2049. # the existence of out dispatch, because that's how
  2050. # structured kernels work. But sometimes functions can be
  2051. # grouped but not be structured, and then you need to check
  2052. # each individual piece, as they may have manual dispatch
  2053. # entries.
  2054. elif isinstance(g, NativeFunctionsGroup) and any(
  2055. backend_index.has_kernel(fn) for fn in g.functions()
  2056. ):
  2057. is_registered = True
  2058. # TODO: this condition is a bit questionable
  2059. # (It has to do with the fact that structured kernels get generated kernels
  2060. # to the Meta + CompositeExplicitAutogradNonFunctional keys).
  2061. elif g.structured and dispatch_key in (
  2062. DispatchKey.Meta,
  2063. DispatchKey.CompositeExplicitAutogradNonFunctional,
  2064. ):
  2065. is_registered = True
  2066. if not is_registered:
  2067. continue
  2068. headers.append(f"#include <ATen/ops/{g.root_name}_native.h>")
  2069. if (
  2070. dispatch_key
  2071. == DispatchKey.CompositeExplicitAutogradNonFunctional
  2072. ):
  2073. headers.append(f"#include <ATen/ops/{g.root_name}.h>")
  2074. if dispatch_key in functions_keys:
  2075. headers.append(
  2076. f"#include <ATen/ops/{g.root_name}_{dispatch_namespace}_dispatch.h>"
  2077. )
  2078. return sorted(set(headers))
  2079. else:
  2080. def operator_headers() -> list[str]:
  2081. headers = ["#include <ATen/NativeFunctions.h>"]
  2082. if dispatch_key == DispatchKey.CompositeExplicitAutogradNonFunctional:
  2083. headers.append("#include <ATen/Functions.h>")
  2084. if dispatch_key in functions_keys:
  2085. headers.append(f"#include <ATen/{dispatch_key!s}Functions.h>")
  2086. return headers
  2087. backend_index = backend_indices[dispatch_key]
  2088. ns_grouped_native_functions = defaultdict(list)
  2089. for grouped_native_function in grouped_native_functions:
  2090. namespace = (
  2091. grouped_native_function.namespace
  2092. if isinstance(grouped_native_function, NativeFunction)
  2093. else grouped_native_function.functional.namespace
  2094. )
  2095. ns_grouped_native_functions[namespace].append(grouped_native_function)
  2096. dispatch_namespace = str(dispatch_key).lower()
  2097. # CompositeImplicitAutogradNestedTensor does not currently use the generated helpers;
  2098. # compilation would fail when the `-Werror=unused-function` flag is set
  2099. gen_dispatch_helpers: bool = (
  2100. dispatch_key != DispatchKey.CompositeImplicitAutogradNestedTensor
  2101. )
  2102. register_dispatch_key_base_env = {
  2103. "extra_cuda_headers": extra_cuda_headers
  2104. if is_cuda_dispatch_key(dispatch_key)
  2105. else "",
  2106. "external_backend_headers": "",
  2107. "dispatch_headers": dest.gen_registration_headers(
  2108. backend_index, per_operator_headers, rocm
  2109. ),
  2110. # ops_headers *could* be sharded, but doesn't seem necessary?
  2111. "ops_headers": operator_headers(),
  2112. "dispatch_helpers": (
  2113. dest.gen_registration_helpers(backend_index)
  2114. if gen_dispatch_helpers
  2115. else []
  2116. ),
  2117. }
  2118. def register_dispatch_key_env_callable(
  2119. gnf: NativeFunction | NativeFunctionsGroup,
  2120. ) -> dict[str, list[str]]:
  2121. return {
  2122. "dispatch_definitions": get_native_function_definitions(
  2123. fm=fm, # noqa: F821
  2124. grouped_native_functions=[gnf],
  2125. dispatch_key=dispatch_key,
  2126. backend_idx=backend_index,
  2127. selector=selector,
  2128. rocm=rocm,
  2129. symint=True,
  2130. skip_dispatcher_op_registration=skip_dispatcher_op_registration,
  2131. gen_dispatch_helpers=gen_dispatch_helpers,
  2132. )
  2133. }
  2134. fm.write_sharded_with_template(
  2135. f"Register{dispatch_key}.cpp",
  2136. "RegisterDispatchKey.cpp",
  2137. grouped_native_functions,
  2138. key_fn=lambda x: x.root_name,
  2139. env_callable=register_dispatch_key_env_callable,
  2140. num_shards=4 if dispatch_key == DispatchKey.CPU else 1,
  2141. base_env=register_dispatch_key_base_env,
  2142. sharded_keys={"dispatch_definitions"},
  2143. )
  2144. for g in structured_native_functions:
  2145. if not g.out.ufunc_inner_loop or not is_ufunc_dispatch_key(dispatch_key):
  2146. continue
  2147. name = g.functional.func.name.name
  2148. if dispatch_key is DispatchKey.CPU:
  2149. if fm is not cpu_fm:
  2150. raise AssertionError("Expected fm to be cpu_fm for DispatchKey.CPU")
  2151. fm.write_with_template(
  2152. f"UfuncCPU_{name}.cpp",
  2153. "UfuncCPU.cpp",
  2154. lambda: {
  2155. "meta_declaration": compute_meta_function_declaration(g),
  2156. "native_declaration": dest.compute_native_function_declaration(
  2157. g, backend_indices[dispatch_key]
  2158. ),
  2159. "native_definitions": dest.compute_ufunc_cpu(g),
  2160. },
  2161. )
  2162. cpu_vec_fm.write_with_template(
  2163. f"UfuncCPUKernel_{name}.cpp",
  2164. "UfuncCPUKernel.cpp",
  2165. lambda: {
  2166. "name": name,
  2167. "native_definitions": dest.compute_ufunc_cpu_kernel(g),
  2168. },
  2169. )
  2170. elif dispatch_key is DispatchKey.CUDA:
  2171. cuda_headers = "#include <ATen/native/cuda/Loops.cuh>"
  2172. if rocm:
  2173. cuda_headers = "#include <ATen/native/hip/Loops.cuh>"
  2174. fm.write_with_template(
  2175. f"UfuncCUDA_{name}.cu",
  2176. "UfuncCUDA.cu",
  2177. lambda: {
  2178. "name": name,
  2179. "cuda_headers": cuda_headers,
  2180. "meta_declaration": compute_meta_function_declaration(g),
  2181. "native_declaration": dest.compute_native_function_declaration(
  2182. g, backend_indices[dispatch_key]
  2183. ),
  2184. "native_definitions": dest.compute_ufunc_cuda(g),
  2185. },
  2186. )
  2187. else:
  2188. raise AssertionError(f"unrecognized {dispatch_key} for ufunc")
  2189. del fm
  2190. gen_aoti_c_shim_files(
  2191. aoti_fm=aoti_fm,
  2192. aoti_backends=aoti_backends,
  2193. native_functions=native_functions,
  2194. backend_indices=backend_indices,
  2195. structured_native_functions=structured_native_functions,
  2196. extra_cuda_headers=extra_cuda_headers,
  2197. update_aoti_c_shim=update_aoti_c_shim,
  2198. extend_aoti_c_shim=extend_aoti_c_shim,
  2199. )
  2200. # BackendSelect is generated specially
  2201. def gen_backend_select() -> dict[str, list[str]]:
  2202. relevant_fns = [
  2203. fn for fn in native_functions if needs_backend_select(fn, selector)
  2204. ]
  2205. return {
  2206. "ops_headers": [
  2207. f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in relevant_fns
  2208. ],
  2209. "backend_select_method_definitions": list(
  2210. mapMaybe(
  2211. ComputeBackendSelect(Target.DEFINITION, selector), relevant_fns
  2212. )
  2213. ),
  2214. "backend_select_function_registrations": list(
  2215. mapMaybe(
  2216. ComputeBackendSelect(Target.REGISTRATION, selector), relevant_fns
  2217. )
  2218. ),
  2219. }
  2220. cpu_fm.write("RegisterBackendSelect.cpp", gen_backend_select)
  2221. schema_selector = selector
  2222. if force_schema_registration:
  2223. schema_selector = SelectiveBuilder.get_nop_selector()
  2224. (
  2225. aten_schema_registrations,
  2226. schema_registrations,
  2227. ) = get_native_function_schema_registrations(
  2228. native_functions=native_functions, schema_selector=schema_selector
  2229. )
  2230. cpu_fm.write(
  2231. "RegisterSchema.cpp",
  2232. lambda: {
  2233. "aten_schema_registrations": []
  2234. if skip_dispatcher_op_registration
  2235. else aten_schema_registrations,
  2236. "schema_registrations": []
  2237. if skip_dispatcher_op_registration
  2238. else schema_registrations,
  2239. },
  2240. )
  2241. def key_func(
  2242. fn: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
  2243. ) -> str:
  2244. return fn.root_name
  2245. cpu_fm.write_sharded(
  2246. "Operators.cpp",
  2247. native_functions,
  2248. key_fn=key_func,
  2249. env_callable=lambda fn: {
  2250. "operator_headers": [f"#include <ATen/ops/{fn.root_name}.h>"],
  2251. "definitions": [
  2252. ComputeOperators(
  2253. Target.DEFINITION,
  2254. static_dispatch_backend_indices=static_dispatch_idx,
  2255. )(fn)
  2256. ],
  2257. },
  2258. base_env={
  2259. "static_dispatch_extra_headers": static_dispatch_extra_headers(
  2260. static_dispatch_idx
  2261. ),
  2262. },
  2263. num_shards=5,
  2264. sharded_keys={
  2265. "operator_headers",
  2266. "definitions",
  2267. "static_dispatch_extra_headers",
  2268. },
  2269. )
  2270. cpu_fm.write("Functions.cpp", dict)
  2271. core_fm.write("TensorMethods.cpp", dict)
  2272. core_fm.write(
  2273. "ATenOpList.cpp",
  2274. lambda: {
  2275. "aten_ops": list(mapMaybe(compute_aten_op, native_functions)),
  2276. },
  2277. )
  2278. def gen_op_headers(
  2279. g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
  2280. ) -> list[str]:
  2281. if isinstance(g, NativeFunctionsViewGroup):
  2282. # view ops always get a functionalization kernel
  2283. headers = [
  2284. f"#include <ATen/ops/{g.view.root_name}_native.h>",
  2285. f"#include <ATen/ops/{g.view.root_name}_ops.h>",
  2286. ]
  2287. if g.view_copy is not None:
  2288. headers += [
  2289. f"#include <ATen/ops/{g.view_copy.root_name}_native.h>",
  2290. f"#include <ATen/ops/{g.view_copy.root_name}_ops.h>",
  2291. ]
  2292. return headers
  2293. elif isinstance(g, NativeFunctionsGroup):
  2294. headers = [
  2295. f"#include <ATen/ops/{g.functional.root_name}_native.h>",
  2296. f"#include <ATen/ops/{g.functional.root_name}_ops.h>",
  2297. f"#include <ATen/ops/{g.out.root_name}_native.h>",
  2298. f"#include <ATen/ops/{g.out.root_name}_ops.h>",
  2299. ]
  2300. if g.inplace is not None:
  2301. headers += [
  2302. f"#include <ATen/ops/{g.inplace.root_name}_native.h>",
  2303. f"#include <ATen/ops/{g.inplace.root_name}_ops.h>",
  2304. ]
  2305. if g.mutable is not None:
  2306. headers += [
  2307. f"#include <ATen/ops/{g.mutable.root_name}_native.h>",
  2308. f"#include <ATen/ops/{g.mutable.root_name}_ops.h>",
  2309. ]
  2310. return headers
  2311. else:
  2312. return [
  2313. f"#include <ATen/ops/{g.root_name}_native.h>",
  2314. f"#include <ATen/ops/{g.root_name}_ops.h>",
  2315. ]
  2316. def functionalization_env_callable(
  2317. g: NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup,
  2318. ) -> dict[str, list[str]]:
  2319. return {
  2320. "ops_headers": gen_op_headers(g),
  2321. "func_definitions": gen_functionalization_definition(
  2322. selector,
  2323. g,
  2324. ),
  2325. "func_registrations": gen_functionalization_registration(
  2326. selector,
  2327. g,
  2328. backend_indices[DispatchKey.CompositeImplicitAutograd],
  2329. ),
  2330. }
  2331. all_groups: list[
  2332. NativeFunction | NativeFunctionsGroup | NativeFunctionsViewGroup
  2333. ] = list(structured_native_functions) + list(
  2334. view_groups # type: ignore[assignment, arg-type, operator]
  2335. )
  2336. # Note: all operators that functionalization needs to handle (mutable and aliasing ops) should be grouped properly.
  2337. # The only reason we really need to deal with direct NativeFunctions here (instead of the groups) is because:
  2338. # (1) We can provide better error checking (error out if someone introduces a mutable op that doesn't obey the grouping logic)
  2339. # (2) functionalization needs to manually register CompositeImplicitAutograd kernels, which might not be grouped.
  2340. # Although this could go away long-term if we add a dedicated dispatch key for decompositions.
  2341. structured_map: dict[OperatorName, NativeFunction] = {
  2342. f.func.name: f
  2343. for f in concatMap(lambda g: list(g.functions()), structured_native_functions)
  2344. }
  2345. view_map: dict[OperatorName, NativeFunction] = {
  2346. f.func.name: f for f in concatMap(lambda g: list(g.functions()), view_groups)
  2347. }
  2348. all_groups.extend(
  2349. f
  2350. for f in native_functions
  2351. if f.func.name not in structured_map and f.func.name not in view_map
  2352. )
  2353. cpu_fm.write_sharded(
  2354. "RegisterFunctionalization.cpp",
  2355. all_groups,
  2356. key_fn=key_func,
  2357. env_callable=functionalization_env_callable,
  2358. num_shards=4,
  2359. sharded_keys={
  2360. "ops_headers",
  2361. "func_definitions",
  2362. "func_registrations",
  2363. "func_add_back_views_definitions",
  2364. "func_add_back_views_registrations",
  2365. },
  2366. )
  2367. cpu_fm.write(
  2368. "FunctionalInverses.h",
  2369. lambda: {
  2370. "view_inverse_declarations": list(
  2371. mapMaybe(
  2372. lambda g: gen_functionalization_view_inverse_declaration(
  2373. selector, g
  2374. ),
  2375. view_groups,
  2376. )
  2377. )
  2378. },
  2379. )
  2380. cpu_fm.write(
  2381. "ViewMetaClasses.h",
  2382. lambda: {
  2383. "view_meta_declarations": list(
  2384. concatMap(
  2385. lambda g: gen_functionalization_view_meta_classes_decl(selector, g),
  2386. view_groups,
  2387. )
  2388. )
  2389. },
  2390. )
  2391. cpu_fm.write(
  2392. "ViewMetaClasses.cpp",
  2393. lambda: {
  2394. "view_meta_implementations": list(
  2395. concatMap(
  2396. lambda g: gen_functionalization_view_meta_classes_impl(selector, g),
  2397. view_groups,
  2398. )
  2399. ),
  2400. "op_headers": list(concatMap(gen_op_headers, view_groups)),
  2401. },
  2402. )
  2403. # Note [view_copy NativeFunctions]
  2404. # Every view operator in native_functions.yaml that is not CompositeImplicitAutograd
  2405. # needs to have a corresponding non-aliasing {view}_copy variant.
  2406. # Backends that use functionalization and don't know how to handle aliasing ops
  2407. # are expected to implement kernels for these {view}_copy kernels instead.
  2408. # The code for {view}_copy operators in core is pretty boilerplate-heavy however,
  2409. # so we codegen the following:
  2410. # (1) A CompositeExplicitAutogradNonFunctional kernel for every {view}_copy operator.
  2411. # These are never explicitly invoked by the functionalization pass,
  2412. # but they could theoretically be called from user code (I added these kernels for completeness,
  2413. # since the ops are part of the public API).
  2414. # (2) A derivative formula for every {view}_copy operator
  2415. # {view}_copy operators can reuse the same derivative formulas as their {view} op counterparts,
  2416. # so rather than stamping all of the entries out in derivatives.yaml,
  2417. # we codegen them in.
  2418. # This is similar to how autograd codegen doesn't require inplace ops to have a derivatives.yaml entry.
  2419. cpu_fm.write(
  2420. "CompositeViewCopyKernels.cpp",
  2421. lambda: {
  2422. "ops_headers": [
  2423. "\n".join(
  2424. f"#include <ATen/ops/{f.root_name}_ops.h>\n"
  2425. # NB: this include is important as it ensures we
  2426. # set the visibility on generated view_copy kernels
  2427. # correctly
  2428. f"#include <ATen/ops/{f.root_name}_native.h>"
  2429. for f in (
  2430. [g.view] if g.view_copy is None else [g.view, g.view_copy]
  2431. )
  2432. )
  2433. for g in view_groups
  2434. ]
  2435. + [
  2436. "\n".join(
  2437. f"#include <ATen/ops/{f.root_name}_ops.h>\n"
  2438. # NB: this include is also important for correct visibility
  2439. f"#include <ATen/ops/{f.root_name}_native.h>"
  2440. for f in [g.inplace, g.mutable, g.functional]
  2441. if f is not None and "generated" not in f.tags
  2442. )
  2443. for g in structured_native_functions
  2444. ],
  2445. "CompositeViewCopyKernel_Definitions": list(
  2446. mapMaybe(
  2447. GenCompositeViewCopyKernel(
  2448. backend_indices[
  2449. DispatchKey.CompositeExplicitAutogradNonFunctional
  2450. ]
  2451. ),
  2452. view_groups,
  2453. )
  2454. ),
  2455. "GeneratedCompositeFunctional_Definitions": list(
  2456. mapMaybe(
  2457. gen_composite_functional_kernel,
  2458. structured_native_functions,
  2459. )
  2460. ),
  2461. "GeneratedCompositeOut_Definitions": list(
  2462. mapMaybe(
  2463. gen_composite_out_kernel,
  2464. structured_native_functions,
  2465. )
  2466. ),
  2467. },
  2468. )
  2469. def gen_declarations_yaml(
  2470. cpu_fm: FileManager, native_functions: Sequence[NativeFunction]
  2471. ) -> None:
  2472. cpu_fm.write(
  2473. "Declarations.yaml",
  2474. lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]),
  2475. )
  2476. def get_torchgen_root() -> Path:
  2477. """
  2478. If you're depending on torchgen out-of-tree, you can use the root to figure
  2479. out the path to native_functions.yaml
  2480. """
  2481. return Path(__file__).parent.resolve()
  2482. def main() -> None:
  2483. parser = argparse.ArgumentParser(description="Generate ATen source files")
  2484. parser.add_argument(
  2485. "-s",
  2486. "--source-path",
  2487. help="path to source directory for ATen",
  2488. default="aten/src/ATen",
  2489. )
  2490. parser.add_argument(
  2491. "-o",
  2492. "--output-dependencies",
  2493. help="output a list of dependencies into the given file and exit",
  2494. )
  2495. parser.add_argument(
  2496. "--dry-run",
  2497. action="store_true",
  2498. help="run without writing any files (still updates outputs)",
  2499. )
  2500. parser.add_argument(
  2501. "--per-operator-headers",
  2502. action="store_true",
  2503. help="generate separate headers per operator in ATen/ops",
  2504. )
  2505. parser.add_argument(
  2506. "-d",
  2507. "--install-dir",
  2508. "--install_dir",
  2509. help="output directory",
  2510. default="build/aten/src/ATen",
  2511. )
  2512. parser.add_argument(
  2513. "--aoti-install-dir",
  2514. "--aoti_install_dir",
  2515. help="output directory for AOTInductor shim",
  2516. default="torch/csrc/inductor/aoti_torch/generated",
  2517. )
  2518. parser.add_argument(
  2519. "--rocm",
  2520. action="store_true",
  2521. help="reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly",
  2522. )
  2523. parser.add_argument(
  2524. "--mps",
  2525. action="store_true",
  2526. help="Generate MPS registration code when set",
  2527. )
  2528. parser.add_argument(
  2529. "--xpu",
  2530. action="store_true",
  2531. help="Generate XPU registration code when set",
  2532. )
  2533. parser.add_argument(
  2534. "--mtia",
  2535. action="store_true",
  2536. help="Generate MTIA registration code when set",
  2537. )
  2538. # TODO: --op-registration-whitelist will be removed when all call-sites
  2539. # for gen.py are moved over to using the operator YAML file for mobile
  2540. # custom build.
  2541. parser.add_argument(
  2542. "--op-registration-whitelist",
  2543. "--op_registration_whitelist",
  2544. nargs="*",
  2545. help="filter op registrations by the whitelist (if set); "
  2546. "each item is `namespace`::`operator name` without overload name; "
  2547. "e.g.: aten::empty aten::conv2d ...",
  2548. )
  2549. parser.add_argument(
  2550. "--op-selection-yaml-path",
  2551. "--op_selection_yaml_path",
  2552. help="Provide a path to the operator selection (for custom build) YAML "
  2553. "that contains the information about the set of selected operators "
  2554. "and their categories (training, ...). Each operator is either a "
  2555. "full operator name with overload or just a bare operator name. "
  2556. "The operator names also contain the namespace prefix (e.g. aten::)",
  2557. )
  2558. parser.add_argument(
  2559. "--backend-whitelist",
  2560. "--backend_whitelist",
  2561. nargs="*",
  2562. help="filter dispatch backend by the whitelist (if set), "
  2563. "e.g.: CPU CUDA QuantizedCPU ...",
  2564. )
  2565. parser.add_argument(
  2566. "--static-dispatch-backend",
  2567. "--static_dispatch_backend",
  2568. nargs="*",
  2569. help="generate static dispatch code for the specific backend (if set)",
  2570. )
  2571. parser.add_argument(
  2572. "--skip-dispatcher-op-registration",
  2573. "--skip_dispatcher_op_registration",
  2574. action="store_true",
  2575. help="Avoid registering operators into the dispatcher.",
  2576. )
  2577. parser.add_argument(
  2578. "--force-schema-registration",
  2579. "--force_schema_registration",
  2580. action="store_true",
  2581. help="force it to generate schema-only registrations for all ops, including"
  2582. "those that are not listed on --op-registration-whitelist",
  2583. )
  2584. parser.add_argument(
  2585. "--generate",
  2586. type=str,
  2587. nargs="*",
  2588. choices=["headers", "sources", "declarations_yaml"],
  2589. default=["headers", "sources", "declarations_yaml"],
  2590. help="Generate only a subset of files",
  2591. )
  2592. parser.add_argument(
  2593. "--update-aoti-c-shim",
  2594. action="store_true",
  2595. help="Update AOTInductor C shim after adding an entry to inductor_fallback_ops in torchgen/aoti/fallback_ops.py. "
  2596. "WARNING: Do not use this unless you are sure what you are doing!!!",
  2597. )
  2598. parser.add_argument(
  2599. "--extend-aoti-c-shim",
  2600. action="store_true",
  2601. help="This Flag indicates the generation of c shims for out-of-tree ATen ops,"
  2602. "which is an extension to the In-tree ATen op c shims. This flag needs to be combined with"
  2603. "---source-path=<out-of-tree native_functions.yaml>"
  2604. "--aoti-install-dir=<in-tree aoti_install_dir>/extend"
  2605. " default is torch/csrc/inductor/aoti_torch/generated/extend"
  2606. "WARNING: Do not use this unless you are sure what you are doing!!!",
  2607. )
  2608. options = parser.parse_args()
  2609. selector = get_custom_build_selector(
  2610. options.op_registration_whitelist,
  2611. options.op_selection_yaml_path,
  2612. )
  2613. native_yaml_path = os.path.join(options.source_path, "native/native_functions.yaml")
  2614. tags_yaml_path = os.path.join(options.source_path, "native/tags.yaml")
  2615. from torchgen.model import dispatch_keys
  2616. # Only a limited set of dispatch keys get CPUFunctions.h headers generated
  2617. # for them; this is the set
  2618. functions_keys = {
  2619. DispatchKey.CPU,
  2620. DispatchKey.CUDA,
  2621. DispatchKey.CompositeImplicitAutograd,
  2622. DispatchKey.CompositeImplicitAutogradNestedTensor,
  2623. DispatchKey.CompositeExplicitAutograd,
  2624. DispatchKey.CompositeExplicitAutogradNonFunctional,
  2625. DispatchKey.Meta,
  2626. DispatchKey.MTIA,
  2627. }
  2628. aoti_backends = {
  2629. DispatchKey.CPU,
  2630. DispatchKey.CUDA,
  2631. # None will generate the aten shim based on aten_shimified_ops
  2632. # which does not bypass the dispatcher
  2633. None,
  2634. }
  2635. # TODO: stop generating CUDA kernels for non-CUDA builds
  2636. ignore_keys = set()
  2637. MPS_KEYS = {DispatchKey.MPS, DispatchKey.SparseMPS, DispatchKey.SparseCsrMPS}
  2638. if options.mps or options.update_aoti_c_shim:
  2639. functions_keys.update(MPS_KEYS)
  2640. aoti_backends.add(DispatchKey.MPS)
  2641. else:
  2642. ignore_keys.update(MPS_KEYS)
  2643. dispatch_keys[:] = [k for k in dispatch_keys if k not in MPS_KEYS]
  2644. if options.xpu or options.update_aoti_c_shim:
  2645. functions_keys.add(DispatchKey.XPU)
  2646. aoti_backends.add(DispatchKey.XPU)
  2647. else:
  2648. ignore_keys.add(DispatchKey.XPU)
  2649. if DispatchKey.XPU in dispatch_keys:
  2650. del dispatch_keys[dispatch_keys.index(DispatchKey.XPU)]
  2651. if not options.mtia:
  2652. ignore_keys.add(DispatchKey.MTIA)
  2653. if DispatchKey.MTIA in dispatch_keys:
  2654. del dispatch_keys[dispatch_keys.index(DispatchKey.MTIA)]
  2655. if options.backend_whitelist:
  2656. dispatch_keys = [
  2657. k
  2658. for k in dispatch_keys
  2659. if is_generic_dispatch_key(k) or str(k) in options.backend_whitelist
  2660. ]
  2661. parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
  2662. valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
  2663. native_functions, backend_indices = (
  2664. parsed_yaml.native_functions,
  2665. parsed_yaml.backend_indices,
  2666. )
  2667. grouped_native_functions = get_grouped_native_functions(native_functions)
  2668. structured_native_functions = [
  2669. g for g in grouped_native_functions if isinstance(g, NativeFunctionsGroup)
  2670. ]
  2671. native_functions_with_view_groups = get_grouped_by_view_native_functions(
  2672. native_functions
  2673. )
  2674. view_groups = [
  2675. g
  2676. for g in native_functions_with_view_groups
  2677. if isinstance(g, NativeFunctionsViewGroup)
  2678. ]
  2679. # NB: It is mandatory to NOT use os.path.join here, as the install directory
  2680. # will eventually be ingested by cmake, which does not respect Windows style
  2681. # path slashes. If you switch this to use os.path.join, you'll get an error
  2682. # like:
  2683. #
  2684. # Syntax error in cmake code when parsing string
  2685. #
  2686. # C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
  2687. #
  2688. # Invalid character escape '\c'.
  2689. core_install_dir = f"{options.install_dir}/core"
  2690. Path(core_install_dir).mkdir(parents=True, exist_ok=True)
  2691. ops_install_dir = f"{options.install_dir}/ops"
  2692. Path(ops_install_dir).mkdir(parents=True, exist_ok=True)
  2693. aoti_install_dir = f"{options.aoti_install_dir}"
  2694. Path(aoti_install_dir).mkdir(parents=True, exist_ok=True)
  2695. core_fm = make_file_manager(options=options, install_dir=core_install_dir)
  2696. cpu_fm = make_file_manager(options=options)
  2697. cpu_vec_fm = make_file_manager(options=options)
  2698. cuda_fm = make_file_manager(options=options)
  2699. ops_fm = make_file_manager(options=options, install_dir=ops_install_dir)
  2700. aoti_fm = make_file_manager(options=options, install_dir=aoti_install_dir)
  2701. device_fms = {"cuda": cuda_fm}
  2702. if options.xpu:
  2703. device_fms["xpu"] = make_file_manager(options=options)
  2704. static_dispatch_idx: list[BackendIndex] = []
  2705. if options.static_dispatch_backend:
  2706. static_dispatch_idx = [
  2707. backend_indices[DispatchKey.parse(key)]
  2708. for key in options.static_dispatch_backend
  2709. ]
  2710. for key in options.static_dispatch_backend:
  2711. dp_key = DispatchKey.parse(key)
  2712. if dp_key not in functions_keys:
  2713. functions_keys.add(dp_key)
  2714. if "sources" in options.generate:
  2715. gen_source_files(
  2716. native_functions=native_functions,
  2717. grouped_native_functions=grouped_native_functions,
  2718. structured_native_functions=structured_native_functions,
  2719. view_groups=view_groups,
  2720. selector=selector,
  2721. static_dispatch_idx=static_dispatch_idx,
  2722. backend_indices=backend_indices,
  2723. aoti_fm=aoti_fm,
  2724. core_fm=core_fm,
  2725. cpu_vec_fm=cpu_vec_fm,
  2726. cpu_fm=cpu_fm,
  2727. device_fms=device_fms,
  2728. dispatch_keys=dispatch_keys,
  2729. functions_keys=functions_keys,
  2730. rocm=options.rocm,
  2731. force_schema_registration=options.force_schema_registration,
  2732. per_operator_headers=options.per_operator_headers,
  2733. skip_dispatcher_op_registration=options.skip_dispatcher_op_registration,
  2734. update_aoti_c_shim=options.update_aoti_c_shim,
  2735. aoti_backends=aoti_backends,
  2736. extend_aoti_c_shim=options.extend_aoti_c_shim,
  2737. )
  2738. if "headers" in options.generate:
  2739. gen_headers(
  2740. native_functions=native_functions,
  2741. valid_tags=valid_tags,
  2742. grouped_native_functions=grouped_native_functions,
  2743. structured_native_functions=structured_native_functions,
  2744. static_dispatch_idx=static_dispatch_idx,
  2745. selector=selector,
  2746. backend_indices=backend_indices,
  2747. core_fm=core_fm,
  2748. cpu_fm=cpu_fm,
  2749. device_fms=device_fms,
  2750. ops_fm=ops_fm,
  2751. dispatch_keys=dispatch_keys,
  2752. functions_keys=functions_keys,
  2753. rocm=options.rocm,
  2754. per_operator_headers=options.per_operator_headers,
  2755. )
  2756. if "declarations_yaml" in options.generate:
  2757. gen_declarations_yaml(native_functions=native_functions, cpu_fm=cpu_fm)
  2758. if options.output_dependencies:
  2759. depfile_path = Path(options.output_dependencies).resolve()
  2760. depfile_name = depfile_path.name
  2761. depfile_stem = depfile_path.stem
  2762. for fm, prefix in [
  2763. (cpu_fm, ""),
  2764. (cpu_vec_fm, "cpu_vec_"),
  2765. (core_fm, "core_"),
  2766. (ops_fm, "ops_"),
  2767. ] + [(device_fm, f"{device}_") for device, device_fm in device_fms.items()]:
  2768. varname = prefix + depfile_stem
  2769. path = depfile_path.parent / (prefix + depfile_name)
  2770. fm.write_outputs(varname, str(path))
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()