graph.py 110 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
7227822792280228122822283228422852286228722882289229022912292229322942295229622972298229923002301230223032304230523062307230823092310231123122313231423152316231723182319232023212322232323242325232623272328232923302331233223332334233523362337233823392340234123422343234423452346234723482349235023512352235323542355235623572358235923602361236223632364236523662367236823692370237123722373237423752376237723782379238023812382238323842385238623872388238923902391239223932394239523962397239823992400240124022403240424052406240724082409241024112412241324142415241624172418241924202421242224232424242524262427242824292430243124322433243424352436243724382439244024412442244324442445244624472448244924502451245224532454245524562457245824592460246124622463246424652466246724682469247024712472247324742475247624772478247924802481248224832484248524862487248824892490249124922493249424952496249724982499250025012502250325042505250625072508250925102511251225132514251525162517251825192520252125222523252425252526252725282529253025312532253325342535253625372538253925402541254225432544254525462547254825492550255125522553255425552556255725582559256025612562256325642565256625672568256925702571257225732574257525762577257825792580258125822583258425852586258725882589259025912592259325942595259625972598259926002601260226032604260526062607260826092610261126122613261426152616261726182619262026212622262326242625262626272628262926302631263226332634263526362637263826392640264126422643264426452646264726482649265026512652
  1. from __future__ import annotations
  2. import contextlib
  3. import functools
  4. import itertools
  5. import logging
  6. import operator
  7. import os
  8. import re
  9. import sys
  10. import time
  11. from collections import defaultdict
  12. from contextlib import contextmanager
  13. from typing import Any, NoReturn, Optional, TYPE_CHECKING, Union
  14. import sympy
  15. from sympy import Expr
  16. import torch
  17. import torch._logging
  18. import torch.fx
  19. from torch import device, Tensor
  20. from torch._decomp import get_decompositions
  21. from torch._dynamo.utils import defake, dynamo_timed
  22. from torch._library.fake_class_registry import FakeScriptObject
  23. from torch._library.opaque_object import is_opaque_type
  24. from torch._library.utils import get_layout_constraint_tag
  25. from torch._logging import LazyString, trace_structured
  26. from torch._prims_common import (
  27. compute_required_storage_length,
  28. make_channels_last_strides_for,
  29. )
  30. from torch._subclasses.fake_tensor import FakeTensor
  31. from torch._utils_internal import full_aoti_runtime_assert
  32. from torch.fx.experimental._backward_state import BackwardState
  33. from torch.fx.experimental.sym_node import magic_methods, method_to_operator
  34. from torch.fx.experimental.symbolic_shapes import (
  35. _get_placeholder_expr,
  36. free_unbacked_symbols,
  37. has_free_symbols,
  38. resolve_unbacked_bindings,
  39. RuntimeAssert,
  40. ShapeEnv,
  41. SympyBoolean,
  42. SymTypes,
  43. )
  44. from torch.fx.node import Node
  45. from torch.fx.passes.reinplace import _is_view_op
  46. from torch.utils._mode_utils import no_dispatch
  47. from torch.utils._ordered_set import OrderedSet
  48. from torch.utils._sympy.numbers import int_oo
  49. from . import config, ir, metrics
  50. from .codegen.common import (
  51. BackendFeature,
  52. DeviceOpOverrides,
  53. FileBackedGraphModule,
  54. get_backend_features,
  55. get_device_op_overrides,
  56. get_wrapper_codegen_for_device,
  57. init_backend_registration,
  58. WorkspaceArg,
  59. )
  60. from .exc import (
  61. CppWrapperCodegenError,
  62. LoweringException,
  63. MissingOperatorWithDecomp,
  64. MissingOperatorWithoutDecomp,
  65. )
  66. from .fx_utils import count_flops_fx
  67. from .ir import (
  68. assign_origin_node,
  69. Constant,
  70. DonatedBuffer,
  71. FixedLayout,
  72. get_device_type,
  73. GraphPartitionSignature,
  74. InputBuffer,
  75. Pointwise,
  76. Reduction,
  77. ShapeAsConstantBuffer,
  78. StorageBox,
  79. TensorBox,
  80. TorchBindObject,
  81. )
  82. from .lowering import (
  83. constrain_to_fake_tensors,
  84. constrain_to_fx_strides,
  85. FALLBACK_ALLOW_LIST,
  86. fallback_handler,
  87. fallback_node_due_to_unsupported_type,
  88. lowerings,
  89. make_fallback,
  90. maybe_layout_constraints,
  91. needs_realized_inputs,
  92. require_contiguous,
  93. tag_to_layout_constraint,
  94. unsupported_output_tensor,
  95. )
  96. from .runtime import autotune_cache
  97. from .runtime.autotune_cache import AutotuneCacheBundler
  98. from .sizevars import SizeVarAllocator
  99. from .utils import (
  100. convert_shape_to_inductor,
  101. gather_origins,
  102. get_cloned_parameter_buffer_name,
  103. get_donated_idxs,
  104. get_sympy_Expr_dtype,
  105. GraphPartitionMap,
  106. is_same_tensor,
  107. maybe_get_suppress_shape_guards_ctx,
  108. normalize_name,
  109. should_assume_input_aligned,
  110. should_fallback_by_default,
  111. SUPPORTED_MKLDNN_DEVICES,
  112. ValueWithLineMap,
  113. )
  114. from .virtualized import NullHandler, V
  115. if TYPE_CHECKING:
  116. from collections.abc import Callable, Iterable, Iterator, Sequence
  117. from types import ModuleType
  118. from torch._higher_order_ops.effects import _EffectType
  119. from torch.fx import GraphModule
  120. from torch.fx.graph import Graph
  121. from .codegen.wrapper import PythonWrapperCodegen
  122. from .dependencies import Dep
  123. from .scheduler import BaseSchedulerNode
  124. CompiledModule = Union[ModuleType, FileBackedGraphModule]
  125. from torch._inductor.codecache import output_code_log
# Module-level logger for this file.
log = logging.getLogger(__name__)
# Artifact logger registered under the "perf_hints" artifact name.
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
# Shorthand for the aten operator namespace.
aten = torch.ops.aten

# Module-wide counter; GraphLowering.__init__ draws post_grad_graph_id from it.
_post_grad_graph_counter = itertools.count()

# Use the fbcode implementation of log_module_code when config.is_fbcode()
# is true; otherwise fall back to a local no-op with a permissive signature.
if config.is_fbcode():
    from torch._inductor.fb.utils import log_module_code
else:

    def log_module_code(*args: Any, **kwargs: Any) -> None:
        # Intentionally does nothing: accepts and ignores every argument.
        pass
  135. def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
  136. assert isinstance(
  137. constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
  138. ), (
  139. "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
  140. )
  141. if isinstance(constant_buffer, sympy.core.numbers.Integer):
  142. return torch.int64
  143. if isinstance(constant_buffer, sympy.Expr):
  144. return get_sympy_Expr_dtype(constant_buffer)
  145. if constant_buffer.is_integer:
  146. return torch.int64
  147. elif constant_buffer.is_float:
  148. return torch.float32
  149. else:
  150. return None
  151. def is_magic_method(op: Any) -> bool:
  152. magic_ops = OrderedSet(method_to_operator(m) for m in magic_methods)
  153. return op in magic_ops
  154. def getattr_recursive(
  155. obj: GraphModule, target: str
  156. ) -> Union[Tensor, torch._C.ScriptObject, GraphModule]:
  157. target_atoms = target.split(".")
  158. attr_itr = obj
  159. for i, atom in enumerate(target_atoms):
  160. if not hasattr(attr_itr, atom):
  161. raise RuntimeError(
  162. f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}"
  163. )
  164. attr_itr = getattr(attr_itr, atom)
  165. return attr_itr
  166. def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]:
  167. ret: dict[Node, tuple[int, ...]] = {}
  168. output_node = g.find_nodes(op="output")[0]
  169. if "user_visible_output_idxs" not in output_node.meta:
  170. return ret
  171. if not isinstance(output_node.args[0], torch.fx.Node):
  172. output_node_args = output_node.args[0]
  173. else:
  174. output_node_args = output_node.args
  175. for idx, node in enumerate(output_node_args):
  176. if idx in output_node.meta["user_visible_output_idxs"]:
  177. ret[node] = output_node.meta["original_output_strides"][idx]
  178. return ret
  179. def extend_user_visible_output_strides(
  180. user_visible_outputs: dict[Node, tuple[int, ...]],
  181. ) -> dict[Node, object]:
  182. """
  183. Extend user_visible_output_strides to include view ops that lead to user-visible outputs.
  184. """
  185. result: dict[Node, object] = {**user_visible_outputs}
  186. queue = [*result.keys()]
  187. visited = OrderedSet([*queue])
  188. while queue:
  189. current = queue.pop()
  190. if (
  191. _is_view_op(current.target)
  192. and current.args
  193. and isinstance(current.args[0], torch.fx.Node)
  194. ):
  195. base = current.args[0]
  196. if base not in visited:
  197. result.setdefault(base, None)
  198. visited.add(base)
  199. queue.append(base)
  200. return result
def mark_nodes_dislike_padding(
    g: Graph, user_visible_output_strides: dict[Node, tuple[int, ...]]
) -> None:
    """
    Nodes like convolution/convolution_backward want their inputs to be dense.
    If we pad their inputs, we end up with extra calls to copy kernels! On the
    other hand, padding usually helps reduction.

    The pass finds nodes that dislike padding. These are nodes that can be reached
    from a convolution/convolution_backward in the backward direction without
    going thru a reduction.

    The result is recorded in place by setting ``node.meta["dislike_padding"] = True``
    on the affected nodes; nothing is returned. The pass does nothing when
    ``config.comprehensive_padding`` is disabled.

    Args:
        g: the FX graph whose nodes are annotated.
        user_visible_output_strides: user-visible output nodes; these (and the
            nodes added by extend_user_visible_output_strides) are also marked
            when ``config.pad_outputs`` is off.
    """
    if not config.comprehensive_padding:
        return

    extended_user_visible_nodes = extend_user_visible_output_strides(
        user_visible_output_strides
    )
    # Overload packets whose nodes are always marked as disliking padding.
    ops_dislike_padding = OrderedSet(
        [
            aten.convolution,
            aten.convolution_backward,
            aten._scaled_mm,
        ]
    )
    # what's a better way to collect the reduction ops?
    # Overload packets that stop the propagation: an input produced by one of
    # these is not marked when its consumer dislikes padding.
    ops_like_padding = OrderedSet(
        [
            aten.var_mean,
            aten.sum,
            aten.mean,
            aten.prod,
            aten.any,
            aten.amin,
            aten.amax,
            aten.min,
            aten.max,
            aten.argmin,
            aten.argmax,
            aten.scatter_reduce,
        ]
    )

    def _get_overload_packet(
        node: torch.fx.Node,
    ) -> Optional[torch._ops.OpOverloadPacket]:
        # Returns the overload packet of a call_function node whose target is
        # an OpOverload, and None for every other node.
        return (
            node.target._overloadpacket
            if node.op == "call_function"
            # hasattr on OpOverloadPacket is slow, do isinstance first
            and isinstance(node.target, torch._ops.OpOverload)
            and hasattr(node.target, "_overloadpacket")
            else None
        )

    # Walk the graph in reverse so a node is visited before the nodes that
    # feed it; marks set on inputs are then seen when those inputs are visited.
    for cur in reversed(g.nodes):
        # Triton kernel wrapper mutations are marked but do not propagate.
        if isinstance(
            cur.target,
            torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation,
        ):
            cur.meta["dislike_padding"] = True
            continue

        # Ops tagged needs_exact_strides are marked but do not propagate.
        if (
            isinstance(cur.target, torch._ops.OpOverload)
            and get_layout_constraint_tag(cur.target)
            == torch._C.Tag.needs_exact_strides
        ):
            cur.meta["dislike_padding"] = True
            continue

        op = _get_overload_packet(cur)
        if not op:
            continue
        if op in ops_dislike_padding:
            cur.meta["dislike_padding"] = True

        if cur.meta.get("dislike_padding", False):
            # propagate
            for prior in cur.all_input_nodes:
                prior_op = _get_overload_packet(prior)
                if not prior_op:
                    continue
                if prior_op not in ops_like_padding:
                    prior.meta["dislike_padding"] = True
        # We only want to mark output nodes. So, move it after the above prior nodes process.
        if not config.pad_outputs and cur in extended_user_visible_nodes:
            cur.meta["dislike_padding"] = True
  281. def is_mkldnn_conv(node: Node) -> bool:
  282. # When mkldnn_fusion is enabled, conv will be replaced by the lowering pattern function.
  283. # See _register_unary_fusion_lowering in torch/_inductor/fx_passes/mkldnn_fusion.py.
  284. if (
  285. getattr(torch.ops, "mkldnn", None) is not None
  286. and getattr(torch.ops.mkldnn, "_convolution_pointwise", None) is not None
  287. and isinstance(node.target, functools.partial)
  288. and len(node.target.args) > 0
  289. and hasattr(node.target.args[0], "targets")
  290. ):
  291. for target in node.target.args[0].targets:
  292. if target.fns[0] in [
  293. torch.ops.mkldnn._convolution_pointwise.default,
  294. torch.ops.mkldnn._convolution_pointwise.binary,
  295. torch.ops.mkldnn._convolution_pointwise_.binary,
  296. ]:
  297. return True
  298. return False
  299. class GraphLowering(torch.fx.Interpreter):
  300. graph_outputs: list[ir.IRNode]
    def __init__(
        self,
        gm: torch.fx.GraphModule,
        example_inputs: Optional[Sequence[object]] = None,
        shape_env: Optional[ShapeEnv] = None,
        graph_id: Optional[int] = None,
        cpp_wrapper: bool = False,
        aot_mode: bool = False,
        layout_opt: Optional[bool] = None,
        extern_node_serializer: Optional[
            Callable[[list[ir.ExternKernelNode]], Any]
        ] = None,
        is_inference: bool = False,
        is_backward: bool = False,
        is_const_graph: bool = False,
        const_output_index: Optional[dict[str, int]] = None,
        const_wrapper_code: Optional[str] = None,
        const_kernel_code: Optional[str] = None,
        const_module: Optional[GraphLowering] = None,
        name: Optional[str] = None,
        inputs_to_check: Optional[Sequence[int]] = None,
        fx_wrapper: bool = False,
    ) -> None:
        """
        Set up the lowering state for ``gm``.

        Most arguments are stored on the instance under the same name. The
        ones with extra handling here:

        Args:
            gm: graph module to lower; passed to the Interpreter base class,
                shallow-copied into ``orig_gm`` and scanned for output-stride
                and padding metadata.
            shape_env: when None a fresh ShapeEnv is created and
                ``reuse_shape_env`` is False; otherwise it is reused.
            layout_opt: when None the value comes from
                ``decide_layout_opt(gm, is_inference=...)``.
            extern_node_serializer: only honoured when ``config.is_fbcode()``
                is true; otherwise the JSON serializer is used.
            const_output_index: also seeds ``folded_constants`` with its keys.
            const_module: when given, ``device_types``, ``device_idxs``,
                ``constants``, ``named_buffers``, ``named_parameters`` and
                ``allocated_constant_name`` are taken from it by reference
                (the same objects, not copies).
        """
        super().__init__(gm)
        self.example_inputs = example_inputs
        self.layout_opt = (
            layout_opt
            if layout_opt is not None
            else self.decide_layout_opt(gm, is_inference=is_inference)
        )
        self.num_channels_last_conv = 0
        self.is_inference = is_inference
        self.is_backward = is_backward
        self.is_const_graph = is_const_graph
        self.const_wrapper_code = const_wrapper_code
        self.const_kernel_code = const_kernel_code
        self.const_module = const_module
        self.inputs_to_check = inputs_to_check

        self.extra_traceback = False  # we do our own error wrapping
        if shape_env is None:
            shape_env = ShapeEnv()
            self.reuse_shape_env = False
        else:
            self.reuse_shape_env = True
        self._shape_env = shape_env
        # We're going to mutate ras_by_symbol as we finish generating them
        # (hence the copy of the shape env's deferred runtime asserts).
        self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = (
            shape_env.deferred_runtime_asserts.copy()
        )
        self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()

        self.sizevars = SizeVarAllocator(shape_env)
        self.graph_input_names: list[str] = []
        self.graph_inputs: dict[str, Union[TensorBox, TorchBindObject, sympy.Expr]] = {}
        self.graph_inputs_original: dict[str, InputBuffer] = {}
        self.partition_maps: Optional[list[GraphPartitionMap]] = None
        self.zero_dim_cpu_tensor_list: OrderedSet[str] = OrderedSet()
        # Shared with const_module (same object) when one is supplied.
        self.device_types: OrderedSet[str] = (
            const_module.device_types if const_module else OrderedSet()
        )
        self.device_idxs: OrderedSet[int] = (
            const_module.device_idxs if const_module else OrderedSet()
        )
        self.device_type = "cpu"
        self.additional_buffer_deps: dict[str, OrderedSet[str]] = defaultdict(
            OrderedSet
        )
        self.additional_star_deps: dict[str, OrderedSet[str]] = defaultdict(OrderedSet)

        # Inplace padding may require Inductor to allocate slightly larger
        # tensor for padding.
        self.buffer_to_padded_size: dict[str, list[int]] = {}

        self.buffers: list[ir.Buffer] = []
        self.operations: list[ir.Operation] = []
        self.const_output_index: dict[str, int] = (
            const_output_index if const_output_index else {}
        )
        self.folded_constants: OrderedSet[str] = (
            OrderedSet(const_output_index.keys())
            if const_output_index
            else OrderedSet()
        )
        # Shared with const_module (same dict objects) when one is supplied.
        self.constants: dict[str, torch.Tensor] = (
            const_module.constants if const_module else {}
        )
        self.named_buffers: dict[str, torch.Tensor] = (
            const_module.named_buffers if const_module else {}
        )
        self.mutated_named_buffers: OrderedSet[torch.Tensor] = gm.meta.get(
            "mutated_named_buffers", OrderedSet()
        )
        self.named_parameters: dict[str, torch.Tensor] = (
            const_module.named_parameters if const_module else {}
        )
        self.torchbind_constants: dict[
            str, Union[torch._C.ScriptObject, FakeScriptObject]
        ] = {}
        self.opaque_value_type_classes: dict[str, type] = {}
        self.seen_subgraphs: dict[str, ir.Subgraph] = {}
        self.constant_reprs: dict[str, str] = {}
        self.removed_operations: OrderedSet[str] = OrderedSet()
        self.removed_buffers: OrderedSet[str] = OrderedSet()
        self.removed_inplace_buffers: OrderedSet[str] = OrderedSet()
        self.mutated_buffers: OrderedSet[str] = OrderedSet()
        self.never_reuse_buffers: OrderedSet[str] = OrderedSet()
        self.inplaced_to_remove: OrderedSet[str] = OrderedSet()
        # Both start as None here; the annotations describe the eventual type.
        self.device_ops: DeviceOpOverrides = None  # type: ignore[assignment]
        self.wrapper_code: PythonWrapperCodegen = None  # type: ignore[assignment]

        from torch._inductor.extern_node_serializer import extern_node_json_serializer

        # A caller-supplied serializer is only used in fbcode; everywhere else
        # the JSON serializer is selected regardless of the argument.
        self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = (
            extern_node_serializer
            if config.is_fbcode() and extern_node_serializer
            else extern_node_json_serializer
        )

        self.current_node: torch.fx.Node = None  # type: ignore[assignment]
        self.lists: dict[str, list[str]] = {}
        self.mutated_inputs: OrderedSet[str] = OrderedSet()
        self.mutated_input_idxs: list[int] = []
        self.name_to_buffer: dict[str, ir.Buffer] = {}
        self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list)
        self.name_to_op: dict[str, ir.Operation] = {}
        self.creation_time = time.time()
        self.name = name  # type: ignore[assignment]
        self.cpp_wrapper = cpp_wrapper
        self.fx_wrapper = fx_wrapper

        # record multi_kernel choice for cpp_wrapper so the second pass knows
        # which sub-kernel is picked. Copy cpp_wrapper to another variable
        # since cpp_wrapper flag is set to false for the first pass of codegen.
        self.record_multi_kernel_choice = cpp_wrapper
        self.multi_kernel_to_choice: dict[str, str] = {}

        self.aot_mode = aot_mode
        self.graph_id = graph_id
        # Next value of the module-level _post_grad_graph_counter.
        self.post_grad_graph_id = next(_post_grad_graph_counter)
        self.scheduler: torch._inductor.scheduler.Scheduler = None  # type: ignore[assignment]

        # record intermediate results for input of UserDefinedTritonKernels
        # This will be used if autotuning is done in one pass.
        self.autotuning_inputs: Optional[list[torch.Tensor]] = None
        self.autotuning_mapping: Optional[dict[str, dict[str, int]]] = None
        self.autotuning_grids: Optional[dict[str, Any]] = None

        # current_device is set only during codegen of a device-specific kernel
        # a graph can have many devices
        self.current_device: Optional[torch.device] = None

        # Only computed when layout optimization is enabled (self.layout_opt).
        self.nodes_prefer_channels_last = (
            self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet()
        )
        # Starts out already containing "aten.convolution_backward".
        self._warned_fallback = OrderedSet(["aten.convolution_backward"])
        self.user_visible_output_strides = get_user_visible_output_strides(gm.graph)
        # Annotates gm.graph nodes in place via node.meta["dislike_padding"].
        mark_nodes_dislike_padding(gm.graph, self.user_visible_output_strides)
        self.cache_key: str = ""  # This is the cache key for the compiled artifact
        self.cache_path: str = ""  # This is the path in the filesystem where the compiled artifact is stored
        self.cache_linemap: list[
            tuple[int, str]
        ] = []  # This is the linemap used by the profiler to mark custom compiled kernels getting run
        # Used if lowering encounters cases where cudagraphs are not supported
        self.disable_cudagraphs_reason: Optional[str] = None

        # only keeping one node per device for stack trace purposes
        self.device_node_mapping: dict[torch.device, torch.fx.Node] = {}
        self.orig_gm: torch.fx.GraphModule = gm.__copy__()
        # Merge the module's buffers/parameters into the name maps. When the
        # maps came from const_module these writes land in its dicts as well.
        for k, v in self.orig_gm.named_buffers():
            self.named_buffers[k] = v
        for k, v in self.orig_gm.named_parameters():
            self.named_parameters[k] = v
        self.dynamo_flat_name_to_original_fqn = self.module.meta.get(  # type: ignore[operator, union-attr]
            "dynamo_flat_name_to_original_fqn", {}
        )
        self.allocated_constant_name: dict[str, str] = (
            const_module.allocated_constant_name if const_module is not None else {}
        )
        init_backend_registration()
        # Per-instance memoized wrapper around get_backend_features
        # (unbounded cache).
        self.get_backend_features = functools.lru_cache(None)(get_backend_features)

        self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
        # Track the buffers that we know is unaligned
        # This can either be a graph input or the output of fallback
        # kernels.
        self.unaligned_buffers: OrderedSet[str] = OrderedSet()
        self.no_fuse_buffer_names: OrderedSet[str] = OrderedSet()

        # Layout constraints for Triton template buffers.
        # Maps buffer name -> expected FixedLayout (computed speculatively without freezing)
        self.buffer_layout_constraints: dict[str, ir.FixedLayout] = {}

        self.low_precision_codegen_ops: OrderedSet[str] = OrderedSet()
        # more aggressive prologue fusion
        self.invoke_quant_ops: OrderedSet[str] = OrderedSet()

        # Below field is related to printing debug intermediate tensor values info for debugging
        self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet()

        # state used by for KernelArgs.workspace
        self.workspace_id = itertools.count()

        # track the current placeholder index that we are processing
        self.placeholder_idx = -1

        self.bw_donated_idxs = get_donated_idxs()

        # Cache for dep size hints to avoid expensive recomputation
        self.dep_size_hint_cache: dict[tuple[Dep, bool], int] = {}
  490. def freeze_runtime_asserts(self) -> None:
  491. self._shape_env.freeze_runtime_asserts()
  492. def symbolic_sizes_strides(
  493. self, ex: torch.Tensor
  494. ) -> tuple[Sequence[Union[int, Expr]], Sequence[Union[int, Expr]]]:
  495. """
  496. Support dynamic shapes and dynamic strides by assigning variables
  497. to each dimension. We duck-shape tensors, so if two tensors
  498. have the same size they get assigned the same symbolic variable.
  499. """
  500. if self.reuse_shape_env:
  501. return convert_shape_to_inductor(ex.size()), convert_shape_to_inductor(
  502. ex.stride()
  503. )
  504. else:
  505. from torch._dynamo.source import ConstantSource
  506. # TODO: this should not be needed once #93059 lands
  507. # https://github.com/pytorch/pytorch/pull/94031#discussion_r1096044816
  508. # TODO: make a dedicated UnknownSource for this?
  509. # NB: This is using the legacy default behavior from
  510. # create_symbolic_sizes_strides_storage_offset but we hope we can
  511. # just delete this entirely
  512. source = ConstantSource(
  513. f"__inductor_unknown_tensor_{len(self._shape_env.backed_var_to_val)}"
  514. )
  515. (
  516. size,
  517. stride,
  518. _,
  519. ) = self._shape_env.create_symbolic_sizes_strides_storage_offset(
  520. ex,
  521. source,
  522. )
  523. r_size = [i.node.expr if isinstance(i, torch.SymInt) else i for i in size]
  524. r_stride = [i.node.expr if isinstance(i, torch.SymInt) else i for i in stride]
  525. return r_size, r_stride
  526. def static_sizes_strides(
  527. self, ex: torch.Tensor
  528. ) -> tuple[list[sympy.Expr], list[sympy.Expr]]:
  529. """
  530. Primarily used to weights
  531. """
  532. size = [sympy.Integer(i) for i in ex.size()]
  533. stride = [sympy.Integer(i) for i in ex.stride()]
  534. return size, stride
  def get_allocation_size(
      self,
      node: Union[
          ir.TensorBox, ir.StorageBox, ir.Buffer, WorkspaceArg, ir.TorchBindObject
      ],
  ) -> Sequence[Expr]:
      """
      Return the size to allocate for ``node``.

      A ``TensorBox`` and then a ``StorageBox`` wrapper are peeled off via
      ``.data``. If the result is a ``ComputedBuffer`` whose name has an
      entry in ``self.buffer_to_padded_size``, that entry is returned;
      otherwise the node's own ``get_size()`` is returned.
      """
      unwrapped = node
      if isinstance(unwrapped, ir.TensorBox):
          unwrapped = unwrapped.data  # type: ignore[assignment]
      if isinstance(unwrapped, ir.StorageBox):
          unwrapped = unwrapped.data  # type: ignore[assignment]

      has_padded_size = (
          isinstance(unwrapped, ir.ComputedBuffer)
          and unwrapped.name in self.buffer_to_padded_size
      )
      if not has_padded_size:
          return unwrapped.get_size()
      # pyrefly: ignore [bad-index, index-error]
      return self.buffer_to_padded_size[unwrapped.name]
  def get_allocation_storage_size(
      self, node: Union[ir.Buffer, WorkspaceArg, ir.TorchBindObject]
  ) -> Expr:
      """
      Return the storage length required for ``node``.

      Combines the allocation size (which takes inplace padding into
      account) with the stride and offset of the node's layout via
      ``compute_required_storage_length``.
      """
      node_layout = node.get_layout()
      alloc_size = self.get_allocation_size(node)  # consider inplace padding
      return compute_required_storage_length(
          alloc_size, node_layout.stride, node_layout.offset  # type: ignore[arg-type]
      )
  def has_feature(
      self,
      device: Union[torch._inductor.ir.IRNode, device, None],
      feature: BackendFeature,
  ) -> bool:
      """
      Return whether ``feature`` is among the backend features reported by
      ``self.get_backend_features`` for the device type of ``device``.
      """
      assert isinstance(feature, BackendFeature), feature
      device_type = get_device_type(device)
      supported_features = self.get_backend_features(device_type)
      return feature in supported_features
  def get_dep_size_hint(self, dep: Dep, count_bytes: bool = True) -> int:
      """
      Return the cached size hint for a dependency, computing it on first use.

      The hint is ``dep.numbytes_hint()`` when ``count_bytes`` is true and
      ``dep.numel_hint()`` otherwise. It is 0 when the dependency has
      unbacked symbols or when a ``KeyError`` is raised while computing it.
      Results are memoized in ``self.dep_size_hint_cache`` keyed on
      ``(dep, count_bytes)`` to avoid expensive recomputation.
      """
      cache_key = (dep, count_bytes)
      if cache_key in self.dep_size_hint_cache:
          return self.dep_size_hint_cache[cache_key]

      hint = 0
      try:
          if not dep.has_unbacked_symbols():
              hint = dep.numbytes_hint() if count_bytes else dep.numel_hint()
      except KeyError:
          # In at least one test (test/inductor/test_torchbind.py) we
          # create a StarDep that doesn't exist in the graph and calling
          # `has_unbacked_symbols()` throws an error.
          pass
      self.dep_size_hint_cache[cache_key] = hint
      return self.dep_size_hint_cache[cache_key]
  def get_current_device_or_throw(self) -> torch.device:
      """
      Return ``self.current_device``.

      Raises:
          RuntimeError: if ``self.current_device`` is falsy (e.g. ``None``).
      """
      device = self.current_device
      if not device:
          raise RuntimeError("No current device")
      return device
  @contextlib.contextmanager
  def set_current_device(self, device: torch.device) -> Iterator[None]:
      """
      Context manager that sets ``self.current_device`` to ``device`` for the
      duration of the ``with`` body and restores the previous value on exit,
      including when the body raises.
      """
      saved_device, self.current_device = self.current_device, device
      try:
          yield
      finally:
          self.current_device = saved_device
  def get_training_phase(self) -> str:
      """
      Return ``"inference"``, ``"backward"`` or ``"forward"``.

      ``self.is_inference`` takes precedence over ``self.is_backward``;
      ``"forward"`` is returned when neither flag is set.
      """
      if self.is_inference:
          phase = "inference"
      elif self.is_backward:
          phase = "backward"
      else:
          phase = "forward"
      return phase
  @staticmethod
  def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool:
      """
      Decide if we should enable layout optimization for this graph based on
      heuristics.

      Args:
          gm: the graph module whose convolution nodes are inspected.
          is_inference: when true, the decision is made from benchmarked
              per-category flop multipliers; otherwise from the
              grouped / in-out-channel / small-channel rules below.

      Returns:
          True if layout optimization should be enabled, False otherwise.
      """
      if not config.layout_optimization:
          return False

      if config.force_layout_optimization:
          return True

      conv_nodes = [
          n for n in gm.graph.nodes if n.target is torch.ops.aten.convolution.default
      ]

      for n in gm.graph.nodes:
          if is_mkldnn_conv(n):
              conv_nodes.append(n)

      nconv = len(conv_nodes)

      if nconv == 0:
          return False

      # For cpu backend and mkldnn enabled, we always use channels_last for better performance.
      if (
          torch.backends.mkldnn.enabled  # pyrefly: ignore [unbound-name]
          and torch.backends.mkldnn.is_available()  # pyrefly: ignore [unbound-name]
          and all(
              n.args[idx].meta["val"].device.type in SUPPORTED_MKLDNN_DEVICES
              for n in conv_nodes
              for idx in [0, 1]
          )
      ):
          return True

      # Following models are skipped due to this:
      # jx_nest_base
      # volo_d1_224
      if len(list(gm.graph.nodes)) >= 300 * nconv:
          log.debug("Skipped layout opt because only a few conv")
          return False

      if any(
          has_free_symbols(n.args[idx].meta["val"])
          for n in conv_nodes
          for idx in [0, 1]
      ):
          log.debug(
              "See perf regression with dynamic shape. Follow up in https://github.com/pytorch/pytorch/issues/102670"
          )
          return False

      def is_grouped(n: Any) -> bool:
          # Last positional arg > 1 and second dim of the weight (args[1]) > 1,
          # i.e. a grouped convolution with in-channels > 1.
          meta_val = n.args[1].meta["val"]  # type: ignore[union-attr, operator]
          assert isinstance(meta_val, torch.Tensor)
          return n.args[-1] > 1 and meta_val.size(1) > 1  # type: ignore[union-attr, operator]

      def is_in_out_channel(n: torch.fx.Node) -> bool:
          # Weight dim 0 is at most half of weight dim 1, and weight dim 2 > 1.
          return (
              n.args[1].meta["val"].size(0) * 2 <= n.args[1].meta["val"].size(1)  # type: ignore[union-attr, operator]
              and n.args[1].meta["val"].size(2) > 1  # type: ignore[union-attr, operator]
          )

      def is_small_channel(n: torch.fx.Node) -> bool:
          # Both of the first two weight dims are <= 64.
          return (
              n.args[1].meta["val"].size(0) <= 64  # type: ignore[union-attr, operator]
              and n.args[1].meta["val"].size(1) <= 64  # type: ignore[union-attr, operator]
          )

      # only grouped convolutions benchmarked as slower in conv samples for inference only
      if is_inference:
          flop_counts: dict[str, float] = defaultdict(float)
          for node in conv_nodes:
              counted_flops = count_flops_fx(node)
              if counted_flops is None:
                  # Previously this message lived in a `for ... else` clause;
                  # since the loop never breaks it fired on every call. Log
                  # it only for nodes whose flops could not be counted.
                  log.debug("Conv inputs meta not found")
                  continue

              if is_grouped(node):
                  node_type = "grouped"
              elif is_small_channel(node):
                  node_type = "small"
              elif is_in_out_channel(node):
                  node_type = "in_out"
              else:
                  node_type = "default"

              flop_counts[node_type] += counted_flops

          # average benchmarked channels last speedup / slowdown, < 1 is speedup.
          # taken from the set of convolution inputs in benchmarks/dynamo/microbenchmarks/operator_inp_logs/torchbench_train/
          # To regenerate these numbers follow https://gist.github.com/eellison/55d7a6ed6f39829d68ac56f95f4df5bb
          GROUPED_MULTIPLIER = 1.358
          DEFAULT_MULTIPLIER = 0.823
          IN_OUT_MULTIPLIER = 0.725
          SMALL_MULTIPLIER = 0.783

          total_flops = sum(flop_counts.values())
          # TODO - get different values per hardware
          weighted_flops = (
              flop_counts["grouped"] * GROUPED_MULTIPLIER
              + flop_counts["small"] * SMALL_MULTIPLIER
              + flop_counts["in_out"] * IN_OUT_MULTIPLIER
              + flop_counts["default"] * DEFAULT_MULTIPLIER
          )
          do_layout_opt = weighted_flops <= total_flops
          if not do_layout_opt:
              log.debug(
                  "Skipped layout opt in inference because weighted flops indicate slowdown, default: %d, channels last: %d",
                  total_flops,
                  weighted_flops,
              )
          return do_layout_opt

      # Channels last layout can dramatically hurt grouped conv perf. E.g.
      # Conv with arguments like
      #   {"input_shape": [32, 224, 112, 112], "weight_shape": [224, 112, 3, 3],
      #    "stride": [2, 2], "padding": [1, 1], "groups": 2}
      # slows down 31x using channels last..

      # But a lot of timm models use depthwise separable convolution which will
      # result in grouped convolution with in-channel size == 1.
      # For those grouped convolution, channels last still helps a lot.
      # E.g.
      # Conv with arguments
      #   {"input_shape": [128, 58, 56, 56], "weight_shape": [58, 1, 3, 3],
      #    "stride": [2, 2], "padding": [1, 1], "groups": 58}
      # get 1.86x speedup with channels last layout.
      #
      # The following heuristics skip using channels-last if the model contains
      # grouped convolution with in-channels > 1.
      if any(map(is_grouped, conv_nodes)):
          log.debug(
              "Skip layout opt because found grouped convolution with >1 in_channels!"
          )
          return False

      # For some models that contain convolution with larger in-channel than out-channel, applying
      # channels last hurts performance.
      # Following models are skipped due to this:
      # - pytorch_unet
      # - phlippe_densenet (slightly worse)
      # - Background_Matting (1.22x -> 0.821x)
      # - pytorch_CycleGAN_and_pix2pix (1.597x -> 1.294x)
      if any(map(is_in_out_channel, conv_nodes)):
          log.debug(
              "Skip layout opt because some convolutions have smaller out_channel"
          )
          return False

      # Following models are skipped due to this:
      # - functorch_maml_omniglot
      if all(map(is_small_channel, conv_nodes)):
          log.debug("Skip layout opt because all convolution channels are too small")
          return False

      return True
  def qualify_name(self, name: str) -> str:
      """Return ``name`` prefixed with ``"<graph name>_"`` when ``self.name`` is not None."""
      graph_name = self.name
      return name if graph_name is None else f"{graph_name}_{name}"
  def make_subgraph(
      self,
      gm: torch.fx.GraphModule,
      example_inputs: list[torch.Tensor],
      subgraph_name: str,
  ) -> SubgraphLowering:
      """
      Make a subgraph of the current graph with all inherited parts, except
      the graph module (`gm`) and `example_inputs`.  The subgraphs are lowered
      separately and lifted into a separate function in the parent output
      wrapper code.  The subgraph name is qualified by the parent graph's
      name. Note that the lifting of subgraph is supported for python wrapper
      only. For cpp wrapper, we inline the subgraphs in the parent wrapper.
      """
      # Settings copied from the parent graph into the subgraph.
      inherited = dict(
          shape_env=self._shape_env,
          cpp_wrapper=self.cpp_wrapper,
          aot_mode=self.aot_mode,
          extern_node_serializer=self.extern_node_serializer,
          is_inference=self.is_inference,
          is_backward=self.is_backward,
          name=self.qualify_name(subgraph_name),
      )
      return SubgraphLowering(
          parent=self,
          gm=gm,
          example_inputs=example_inputs,
          **inherited,
      )
  def find_nodes_prefer_channels_last(self) -> OrderedSet[Node]:
      """
      Collect the FX nodes that prefer a channels-last layout.

      The rule to decide if a node prefers channels last is simple:
      1. it is the input/output of a convolution, or
      2. one of its users prefers channels last.

      Rule 1 exists because cudnn runs a faster convolution kernel for
      channels last inputs. Rule 2 makes sure that indirect inputs to a
      convolution also prefer channels last.

      Consider the scenario: conv -> batch-norm -> relu -> conv
      Without rule 2, batch-norm output may use a contiguous layout. That
      will cause 2 extra copies:
      1. The output of batch-norm should be channels last initially since its
         input is a conv's output. Forcing the batch-norm's output to be
         contiguous results in the first copy.
      2. The second conv's input is initially contiguous. This layout is
         propagated from the batch-norm's output. We need to convert it to
         channels last layout, which results in the second copy.
      With rule 2, all the tensors in the chain use channels last layout, so
      both copies can be saved.
      """
      conv_target = torch.ops.aten.convolution.default
      nodes_cannot_propagate = [torch.ops.aten.bmm.default]
      output_set = OrderedSet[Node]()
      # First aten convolution met while walking backwards, i.e. the last one
      # in graph order.
      last_conv = None

      for n in reversed(self.module.graph.nodes):  # type: ignore[arg-type, union-attr]
          if n.target is conv_target:
              output_set.add(n)
              if last_conv is None:
                  last_conv = n
          elif n.target in nodes_cannot_propagate:
              pass
          elif is_mkldnn_conv(n):
              output_set.add(n)
          elif any(user in output_set for user in n.users):
              output_set.add(n)

      # need a second pass to add downstream nodes of those channel last nodes to the sets.
      # This pass is especially needed to avoid mix-layout kernel inputs in backward pass.
      #
      # Let's say a conv-batchnorm 's output is passed to relu whose output is in turn returned
      # from the fwd graph. Without this second pass, we will force relu's output to be contiguous.
      # Then in the kernel in backward pass, the contiguous output of relu may be mix with other channels last
      # tensors and passed to a kernel.
      #
      # This pass improve yolov3 training speedup from 1.116x (worse than disabling layout optimization speedup 1.196x) to 1.457x.
      # It also improves dla102 training speedup from 1.240x (worse than disabling layout optimization speedup 1.523x) to 1.835x .
      # This also helps the following models:
      # - res2net101_26w_4s
      # - res2net50_14w_8s
      # - sebotnet33ts_256
      for n in self.module.graph.nodes:  # type: ignore[union-attr]
          # layout propagation ends at last conv node, which will benefit vision transformers.
          if last_conv is not None and n == last_conv:
              break
          if n not in output_set:
              continue
          for user in n.users:
              if user.target not in nodes_cannot_propagate:
                  output_set.add(user)

      return output_set
  def warn_fallback(self, name: str) -> None:
      """Log "Using FallbackKernel" for ``name`` once; repeated names are ignored."""
      if name in self._warned_fallback:
          return
      self._warned_fallback.add(name)
      perf_hint_log.info("Using FallbackKernel: %s", name)
  def add_device_info(self, device: torch.device) -> None:
      """
      Record that ``device`` is used by this graph.

      Adds the device type to ``self.device_types``, adds the index (when it
      is not None) to ``self.device_idxs``, and, the first time a device is
      seen while ``V.graph.current_node`` is truthy, stores that node in
      ``self.device_node_mapping``.
      """
      self.device_types.add(device.type)
      if device.index is not None:
          self.device_idxs.add(device.index)

      first_sighting = V.graph.current_node and device not in self.device_node_mapping
      if first_sighting:
          self.device_node_mapping[device] = V.graph.current_node
  @property
  def fake_mode(self) -> torch._subclasses.fake_tensor.FakeTensorMode:
      """Return ``V.fake_mode``."""
      current_fake_mode = V.fake_mode
      return current_fake_mode
  def try_get_buffer(
      self, buffer_name: str
  ) -> Optional[Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject]]:
      """
      Look up ``buffer_name`` and return the matching object, or None.

      Lookup order: ``self.name_to_buffer``, then ``self.graph_inputs``, then
      ``self.constants``. For a constant, a new ``ir.ConstantBuffer`` with a
      fixed layout built from the tensor in ``V.graph.constants`` is returned.
      """
      if buffer_name in self.name_to_buffer:
          return self.name_to_buffer[buffer_name]
      if buffer_name in self.graph_inputs:
          return self.graph_inputs[buffer_name]
      if buffer_name not in self.constants:
          return None

      const_tensor = V.graph.constants[buffer_name]
      const_layout = ir.FixedLayout(
          const_tensor.device,
          const_tensor.dtype,
          *V.graph.static_sizes_strides(const_tensor),
      )
      return ir.ConstantBuffer(name=buffer_name, layout=const_layout)
  def add_symbol_graph_input(self, symbol: sympy.Expr) -> None:
      """
      Always raises: this operation is not supported on the main graph.

      Raises:
          RuntimeError: unconditionally.
      """
      message = "Should not be called for the main graph"
      raise RuntimeError(message)
  def get_buffer(
      self, buffer_name: str
  ) -> Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject]:
      """
      Return the object ``self.try_get_buffer`` finds for ``buffer_name``.

      Raises:
          RuntimeError: if ``try_get_buffer`` returns None.
      """
      found = self.try_get_buffer(buffer_name)
      if found is None:
          raise RuntimeError(f"Failed to find buffer matching name {buffer_name}")
      return found
  def get_dtype(self, buffer_name: str) -> torch.dtype:
      """
      Return the dtype associated with ``buffer_name``.

      Lookup order:
      1. ``self.constants``.
      2. If the scheduler has a ``mutation_real_name`` mapping containing
         ``buffer_name``, the buffer being mutated (looked up in
         ``self.name_to_buffer``, then ``self.graph_inputs``).
      3. ``self.name_to_buffer``, then ``self.graph_inputs``.
      4. Names of the form ``as_strided(<name>, ...`` or
         ``reinterpret_tensor(<name>, ...``: the dtype of the wrapped
         ``<name>``.

      Raises:
          KeyError: if no dtype can be found for ``buffer_name``.
      """
      if buffer_name in self.constants:
          return self.constants[buffer_name].dtype
      # For a mutation op we should return the dtype of the buffer being mutated
      if (
          hasattr(self.scheduler, "mutation_real_name")
          and buffer_name in self.scheduler.mutation_real_name
      ):
          mutated_buf = self.scheduler.mutation_real_name[buffer_name]
          if mutated_buf in self.name_to_buffer:
              return self.name_to_buffer[mutated_buf].get_dtype()
          if mutated_buf in self.graph_inputs:
              return self.graph_inputs[mutated_buf].get_dtype()
      if buffer_name in self.name_to_buffer:
          return self.name_to_buffer[buffer_name].get_dtype()
      if buffer_name in self.graph_inputs:
          return self.graph_inputs[buffer_name].get_dtype()
      m = re.match(r"(as_strided|reinterpret_tensor)\(([a-zA-Z0-9_]+),", buffer_name)
      if m:
          # Group 1 is the wrapper function name; group 2 is the wrapped
          # buffer name, which is what we need to resolve.
          return self.get_dtype(m.group(2))
      raise KeyError(f"could not find {buffer_name}")
  def get_numel(self, buffer_name: str) -> Union[int, Expr]:
      """
      Return the number of elements of the named constant, buffer or input.

      Constants are checked first, then ``self.name_to_buffer`` (where a
      buffer without tensor output counts as 1), then ``self.graph_inputs``.

      Raises:
          KeyError: if ``buffer_name`` is found in none of them.
      """
      if buffer_name in self.constants:
          return self.constants[buffer_name].numel()

      if buffer_name in self.name_to_buffer:
          registered = self.name_to_buffer[buffer_name]
          return registered.get_numel() if registered.has_tensor_output() else 1

      if buffer_name in self.graph_inputs:
          return self.graph_inputs[buffer_name].get_numel()

      raise KeyError(f"could not find {buffer_name}")
  def run(self, *args: Any) -> Any:  # type: ignore[override]
      """Run the base-class ``run`` inside a ``dynamo_timed("GraphLowering.run")`` region."""
      with dynamo_timed("GraphLowering.run"):
          result = super().run(*args)
      return result
  def register_operation(self, op: ir.Operation) -> str:
      """
      Register ``op`` with the graph and return its newly assigned name.

      The name is ``op<N>`` (qualified by the graph name), where N is the
      number of operations registered so far. The operation is appended to
      ``self.operations``, indexed in ``self.name_to_op``, and has its
      ``operation_name`` set. An operation that already has a name is
      rejected by assertion.
      """
      assert op.operation_name is None, f"Operation registered twice: {op}"
      assert isinstance(op, ir.Operation)
      index = len(self.operations)
      new_name = self.qualify_name(f"op{index}")
      self.operations.append(op)
      self.name_to_op[new_name] = op
      op.operation_name = new_name
      return new_name
  def register_buffer(self, buffer: ir.Buffer, *, set_name: bool = False) -> str:
      """
      Register ``buffer`` with the graph and return its newly assigned name.

      The name is ``buf<N>`` (qualified by the graph name), where N is the
      number of buffers registered so far. The buffer's device, if any, is
      recorded via ``add_device_info`` unless the buffer is a zero-element
      ``ComputedBuffer`` on CPU. When ``set_name`` is true the name is also
      written to ``buffer.name``.
      """
      index = len(self.buffers)
      new_name = self.qualify_name(f"buf{index}")
      self.buffers.append(buffer)
      self.name_to_buffer[new_name] = buffer

      device = buffer.get_device()
      if device is not None:
          # Skip empty CPU tensor so that CUDA graphs can succeed, see https://github.com/pytorch/pytorch/pull/114144
          is_empty_cpu_buffer = (
              isinstance(buffer, ir.ComputedBuffer)
              and buffer.is_zero_elements()
              and device == torch.device("cpu")
          )
          if not is_empty_cpu_buffer:
              self.add_device_info(device)

      if set_name:
          buffer.name = new_name
      return new_name
  def register_operation_list(self, operation_names: list[str]) -> str:
      """
      Store ``operation_names`` in ``self.lists`` under the qualified name
      ``list_<name1>_<name2>...`` and return that name.
      """
      joined_names = "_".join(operation_names)
      list_name = self.qualify_name("list_" + joined_names)
      self.lists[list_name] = operation_names
      return list_name
  def register_users_of(
      self, node_output: Union[Iterable[ir.IRNode], ir.IRNode]
  ) -> None:
      """
      Record every ``ir.TensorBox`` in ``node_output`` as a user of the
      buffers it reads.

      Lists and tuples are traversed recursively. For each TensorBox found,
      it is appended to ``self.name_to_users[read_name]`` for every name in
      its ``get_read_names()``.
      """

      def visit(item: Union[Iterable[ir.IRNode], ir.IRNode]) -> None:
          if isinstance(item, (list, tuple)):
              for child in item:
                  visit(child)
          if isinstance(item, ir.TensorBox):
              for read_name in item.get_read_names():
                  self.name_to_users[read_name].append(item)

      visit(node_output)
  def mark_buffer_mutated(self, name: str) -> None:
      """
      When a buffer is mutated we need to make sure all the reads to
      the old version are realized before the mutation happens.

      Adds ``name`` to ``self.mutated_buffers`` and calls ``realize()`` on
      every user recorded for it in ``self.name_to_users``.
      """
      assert isinstance(name, str)
      self.mutated_buffers.add(name)

      # Membership is tested first so that no entry is created for an
      # unknown name.
      if name in self.name_to_users:
          for reader in self.name_to_users[name]:
              reader.realize()
  def get_original_value_of_constant(self, name: str) -> torch.Tensor:
      """
      In AOTI, module buffers may have been mutated during the tracing and compilation.
      Thus we need to read from previously stored original buffers, to make sure the
      generated model.so uses correct initial values.

      Returns the entry of ``self.module.meta`` stored under the cloned
      parameter/buffer name when present, else ``self.constants[name]``.
      """
      assert name in self.allocated_constant_name and name in self.constants, (
          "Can not find the original value for " + name
      )
      cloned_key = get_cloned_parameter_buffer_name(self.allocated_constant_name[name])
      if cloned_key in self.module.meta:  # type: ignore[operator]
          return self.module.meta[cloned_key]  # type: ignore[index]
      return self.constants[name]
  def allocate_non_dup_const_name(
      self, name: Optional[str], data: Union[Tensor]
  ) -> str:
      """
      Register ``data`` as a graph constant and return the name it is stored under.

      Unless runtime constant folding is enabled, an existing constant for
      which ``is_same_tensor(data, value)`` holds is reused and its name
      returned. Otherwise a name is derived from ``name`` (or
      ``constant<N>`` when None): prefixed with ``constant_`` if it starts
      with a digit, qualified with the graph name, normalized, and suffixed
      with ``_<k>`` until it is not already in ``self.constants``. The
      constant, its repr string and its original name are then recorded.
      """
      if not config.aot_inductor.use_runtime_constant_folding:
          existing_name = next(
              (
                  known_name
                  for known_name, known_value in self.constants.items()
                  if is_same_tensor(data, known_value)
              ),
              None,
          )
          if existing_name is not None:
              return existing_name

      if name is None:
          name = f"constant{len(self.constants)}"
      orig_name = name
      if name[0].isdigit():
          name = f"constant_{name}"

      # We may generate a var name for each constant in the codegen.
      # Let's only keep sane characters.
      sanitized = normalize_name(self.qualify_name(name))

      unique_name = sanitized
      suffix = 0
      while unique_name in self.constants:
          unique_name = f"{sanitized}_{suffix}"
          suffix += 1

      self.constants[unique_name] = data
      self.constant_reprs[unique_name] = (
          f"{data.device!r} {data.dtype!r} "
          f"{tuple(data.size())!r} {tuple(data.stride())!r} "
          f"{hash(data):x}"
      )
      self.allocated_constant_name[unique_name] = orig_name  # type: ignore[assignment]
      return unique_name
  def add_tensor_constant(
      self, data: Tensor, name: Optional[str] = None
  ) -> TensorBox:
      """
      Register ``data`` as a constant (see ``allocate_non_dup_const_name``)
      and return a ``TensorBox`` wrapping an ``ir.ConstantBuffer`` with a
      fixed layout built from the tensor's device, dtype, sizes and strides.
      """
      const_name = self.allocate_non_dup_const_name(name, data)
      const_layout = FixedLayout(
          data.device, data.dtype, *self.static_sizes_strides(data)
      )
      const_buffer = ir.ConstantBuffer(name=const_name, layout=const_layout)
      return TensorBox.create(const_buffer)
  def constant_name(self, name: str, device_override: Optional[torch.device]) -> str:
      """
      We AOT copy constants to the devices they are needed on.
      If device_override doesn't match the constant's device, then
      copy it and return a different name.

      When ``device_override`` is None or equals the constant's device,
      ``name`` is returned unchanged.
      """
      if self.constants[name].device == device_override or device_override is None:
          return name

      def is_registered_in(registry: Any) -> bool:
          # True when `name` equals the normalized form of any key in `registry`.
          return any(name == normalize_name(key) for key in registry)

      with torch.utils._python_dispatch._disable_current_modes():
          # The caller might have a fake tensor mode set, which would create
          # a fake tensor when calling .to, so unset modes here.
          copied_name = self.allocate_non_dup_const_name(
              f"{name}_{device_override.type}{device_override.index or 0}",
              self.constants[name].to(device_override),
          )

          assert copied_name in self.constants, (
              f"{copied_name} should be in V.graph.constants already"
          )

          # register device-copied buffers and parameters to graph as well
          # to codegen correct torch::aot_inductor::ConstantType for them rather than `Unknown`
          if is_registered_in(self.named_buffers):
              self.named_buffers[copied_name] = self.constants[copied_name]

          if is_registered_in(self.named_parameters):
              self.named_parameters[copied_name] = self.constants[copied_name]

          return copied_name
  # pyrefly: ignore [bad-override]
  def placeholder(
      self,
      target: str,  # type: ignore[override]
      args: tuple[object],  # type: ignore[override]
      kwargs: dict[str, object],
  ) -> Union[Expr, TensorBox, None]:
      """
      Lower a placeholder node into a graph input.

      Increments ``self.placeholder_idx``, obtains the example value from the
      base-class ``placeholder``, and qualifies ``target`` with the graph
      name. The qualified name is always appended to
      ``self.graph_input_names``; what is stored in ``self.graph_inputs`` and
      returned depends on the example's type:

      - ``SymTypes``: a symbolic expression.
      - ``int`` / ``bool`` / ``float``: ``sympy.sympify(example)``.
      - ``FakeScriptObject``: a ``TorchBindObject``.
      - ``None`` or ``BackwardState``: nothing is stored; returns None.
      - ``torch.Generator``: an ``ir.GeneratorState``.
      - ``torch.Tensor``: a ``TensorBox`` around a ``DonatedBuffer`` or an
        ``InputBuffer`` with a fixed layout.

      Any other example type fails the tensor assertion.
      """
      self.placeholder_idx += 1
      example = super().placeholder(target, args, kwargs)  # type: ignore[arg-type]
      target = self.qualify_name(target)
      if isinstance(example, SymTypes):
          # TODO fix partitioning issue and re-enable for backward
          # https://github.com/pytorch/pytorch/issues/155468.
          if not V.graph.is_backward:
              expr = _get_placeholder_expr(example.node)
          else:
              expr = example.node.expr
          self.graph_inputs[target] = expr
          self.graph_input_names.append(target)
          return expr
      elif isinstance(example, (int, bool, float)):
          expr = sympy.sympify(example)
          self.graph_inputs[target] = expr
          self.graph_input_names.append(target)
          return expr
      elif isinstance(example, FakeScriptObject):
          obj = TorchBindObject(name=target, value=example)
          self.graph_inputs[target] = obj
          self.graph_input_names.append(target)
          return obj
      elif example is None:
          # Only the name is recorded; there is no entry in graph_inputs.
          self.graph_input_names.append(target)
          return None
      if isinstance(example, BackwardState):
          # Ignored arg, must be unused
          # Alternately we could filter this out in AotAutograd
          self.graph_input_names.append(target)
          return None
      # See note: Note: [Generator arguments in AOTDispatcher]
      elif isinstance(example, torch.Generator):
          # The generator placeholder must have exactly one user, and that
          # user must be one of the two targets listed below.
          assert len(V.graph.current_node.users) == 1 and next(
              iter(V.graph.current_node.users)
          ).target in (
              torch._prims.rng_prims.graphsafe_run_with_rng_state,
              torch.ops.higher_order.invoke_subgraph,
          )
          gen = ir.GeneratorState(name=target, device=example.device)
          self.graph_inputs[target] = gen  # type: ignore[assignment]
          self.graph_input_names.append(target)
          return gen

      assert isinstance(example, torch.Tensor), example
      # todo(chilli): We can remove the last check once we turn buffers into
      # static shape tensors. That's a hack to workaround Inductor believing
      # the buffer should be static but us passing in a fake tensor with
      # symbolic shapes.
      if not example._has_symbolic_sizes_strides:
          # the first N inputs are weights
          sizes, strides = self.static_sizes_strides(example)
      else:
          sizes, strides = self.symbolic_sizes_strides(example)  # type: ignore[assignment]

      # In the backward graph, inputs whose placeholder index is listed in
      # bw_donated_idxs are wrapped as DonatedBuffer instead of InputBuffer.
      if (
          self.is_backward
          and self.bw_donated_idxs
          and self.placeholder_idx in self.bw_donated_idxs
      ):
          tensor = TensorBox.create(
              DonatedBuffer(
                  name=target,
                  layout=FixedLayout(example.device, example.dtype, sizes, strides),
              )
          )
      else:
          # TODO(jansel): handle input aliasing
          tensor = TensorBox.create(
              InputBuffer(
                  name=target,
                  layout=FixedLayout(example.device, example.dtype, sizes, strides),
              )
          )

      self.graph_inputs[target] = tensor
      self.graph_input_names.append(target)
      # Keep the buffer underneath the TensorBox wrappers as the "original" input.
      self.graph_inputs_original[target] = tensor.data.data  # type: ignore[union-attr]
      if self.current_node.users:  # cudagraphs should work with an unused CPU input
          self.add_device_info(example.device)

      # Note: [Input Alignment handling in Inductor]
      # Alignment matters for generating efficient code. Some operations,
      # e.g. vectorized loads, can only be performed on aligned inputs.
      #
      # But if we codegen assuming aligned inputs and then get unaligned
      # inputs at runtime, then we are forced to clone - which is bad for
      # both perf and memory usage.
      #
      # One option would be to guard on storage_offset%ALIGNMENT, and then
      # codegen based on this. But storage_offset guards turned out to be
      # expensive and cause recompiles; Instead, we're generating code
      # based on the alignment of the example input without guarding.
      with maybe_get_suppress_shape_guards_ctx():
          if not should_assume_input_aligned(example):
              self.unaligned_buffers.add(target)
      return tensor
  def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any:  # type: ignore[type-arg, override]
      """
      Lower a call_function node by dispatching ``target`` to its lowering.

      Steps:
      1. ``operator.getitem`` on a list/tuple/dict is delegated to the base
         class.
      2. Targets carrying ``_inductor_lowering_function`` are called
         directly.
      3. A target without a registered lowering gets a fallback created
         (allow-listed ops, or any op when ``config.implicit_fallbacks`` is
         set); otherwise a missing-operator error is raised.
      4. Layout constraints, if any, are applied to the arguments; the
         lowering (or the fallback handler when the node's meta contains
         ``"should_fallback"``) is invoked; and mutations are propagated
         back to the original arguments.

      Raises:
          MissingOperatorWithDecomp / MissingOperatorWithoutDecomp: when
              there is no lowering and no fallback may be created.
          LoweringException: wraps any exception raised in step 4, carrying
              the current node's ``stack_trace`` meta entry when available.
      """
      if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
          return super().call_function(target, args, kwargs)

      # hasattr on OpOverloadPacket is slow, check isinstance first
      if not isinstance(target, torch._ops.OpOverloadPacket) and hasattr(
          target, "_inductor_lowering_function"
      ):
          # passthrough lowerings from .pattern_matcher
          return target(*args, **kwargs)

      if target not in lowerings:
          assert isinstance(target, torch._ops.OpOverload), (
              f"{target} is not an OpOverload"
          )
          # First dot-separated component of the operator name.
          base_name = target.name().split(".")[0]
          if base_name in FALLBACK_ALLOW_LIST:
              make_fallback(target, warn=False, override_decomp=True)
          elif config.implicit_fallbacks:
              # The error class is only used here to format the log message.
              error = (
                  MissingOperatorWithDecomp
                  if get_decompositions([target])
                  else MissingOperatorWithoutDecomp
              )
              log.info(
                  "Creating implicit fallback for:\n%s",
                  error.operator_str(target, args, kwargs),
              )

              tag: Optional[torch._C.Tag] = get_layout_constraint_tag(
                  target, with_default=False
              )
              if (
                  tag is None
                  and torch._library.utils.is_builtin(target)
                  and self.is_backward
              ):
                  # for implicit fallback ATen ops during backward, if there
                  # is no layout constraint tag, we conservatively require contiguous
                  # input since some eager kernels do not
                  # support non-contiguous inputs. Otherwise they may silently cause
                  # accuracy problems. Check https://github.com/pytorch/pytorch/issues/140452
                  # We only do this For ATen ops and for backward.
                  #
                  # TODO: should really switch to "needs_fixed_stride" constraint on these
                  # and identify them one by one.
                  decided_constraint: Optional[Callable[..., tuple[Any, Any]]] = (
                      require_contiguous
                  )
              else:
                  default_tag: torch._C.Tag = get_layout_constraint_tag(
                      target, with_default=True
                  )
                  decided_constraint = tag_to_layout_constraint(default_tag)

              make_fallback(target, layout_constraint=decided_constraint)

          elif get_decompositions([target]):
              # There isn't a good way to dynamically patch this in
              # since AOT Autograd already ran.  The error message tells
              # the user how to fix it.
              raise MissingOperatorWithDecomp(target, args, kwargs)
          else:
              raise MissingOperatorWithoutDecomp(target, args, kwargs)

      try:
          log.debug(" via %s", lowerings[target])  # type: ignore[index]

          n = self.current_node
          layout_constraints = maybe_layout_constraints(target)
          if layout_constraints:
              # Remember the unconstrained inputs so mutations can be written
              # back to them after the lowering runs.
              old_args, old_kwargs = args, kwargs
              if layout_constraints is constrain_to_fake_tensors:
                  # only constrain_to_fake_tensor if this exists.
                  # otherwise, no constraints at all: the implication is
                  # that this operator was inserted by a custom pass
                  # so we'll give them the freedom.
                  if "eager_input_vals" in n.meta:
                      fake_args, fake_kwargs = n.meta["eager_input_vals"]

                      # (fake_args, fake_kwargs) might not align with (args, kwargs).
                      # we need to normalize them based on the schema
                      assert isinstance(target, torch._ops.OpOverload)

                      def normalize(args: Any, kwargs: Any) -> tuple[Any, Any]:
                          result = torch.fx.operator_schemas.normalize_function(
                              target, args, kwargs
                          )
                          assert result is not None
                          return result[0], result[1]

                      fake_args, fake_kwargs = normalize(fake_args, fake_kwargs)
                      args, kwargs = normalize(args, kwargs)
                      old_args, old_kwargs = normalize(old_args, old_kwargs)

                      args, kwargs = constrain_to_fake_tensors(
                          args, kwargs, fake_args, fake_kwargs
                      )
              else:
                  args, kwargs = layout_constraints(n, *args, **kwargs)

          if "should_fallback" in n.meta:
              out = fallback_handler(target, add_to_fallback_set=False)(
                  *args, **kwargs
              )
          else:
              out = lowerings[target](*args, **kwargs)  # type: ignore[index]

          if layout_constraints:
              # layout_constraints are allowed to make new copies of the inputs.
              # if they do, and if the target is mutable, then we need to
              # write the new values back into the original inputs.
              self.propagate_mutation(n, old_args, old_kwargs, args, kwargs)  # type: ignore[possibly-undefined]

          return out
      except Exception as e:
          # Attach the node's recorded stack trace (if any) to the wrapped
          # exception; the original traceback is kept and the implicit
          # exception context is suppressed.
          stack_trace = None
          if (
              hasattr(self, "current_node")
              and self.current_node is not None
              and hasattr(self.current_node, "meta")
              and self.current_node.meta is not None
          ):
              stack_trace = self.current_node.meta.get("stack_trace", None)
          raise LoweringException(
              e, target, args, kwargs, stack_trace=stack_trace
          ).with_traceback(e.__traceback__) from None
  @staticmethod
  def can_inline_constant(t: torch.Tensor) -> bool:
      """
      True if this is a small constant attr that will be inlined, i.e. a
      1-D tensor with at most 8 elements.
      """
      shape = t.shape
      if len(shape) != 1:
          return False
      return shape[0] <= 8
  # pyrefly: ignore [bad-override]
  def get_attr(
      self,
      target: str,  # type: ignore[override]
      args: tuple[()],  # type: ignore[override]
      kwargs: dict[str, object],
  ) -> Union[
      Constant, TensorBox, ShapeAsConstantBuffer, ir.Subgraph, TorchBindObject
  ]:
      """
      Lower a get_attr node, whose value is an attribute of ``self.module``.

      - A ``GraphModule`` becomes an ``ir.Subgraph``, cached per target in
        ``self.seen_subgraphs``.
      - A script object, fake script object or opaque-type value is recorded
        in ``self.torchbind_constants`` and returned as a ``TorchBindObject``.
      - A tensor is either kept as a named tensor constant, or inlined: as a
        ``Constant`` when its shape is ``()``, or through the ``tensor``
        lowering when ``can_inline_constant`` accepts it.
      """
      # this is a constant
      attr_value = getattr_recursive(self.module, target)  # type: ignore[arg-type]

      if isinstance(attr_value, torch.fx.GraphModule):
          # Reuse the existing subgraph if we have seen it before already.
          if target not in self.seen_subgraphs:
              self.seen_subgraphs[target] = ir.Subgraph(
                  name=target, graph_module=attr_value
              )
          return self.seen_subgraphs[target]

      is_torchbind_like = (
          isinstance(attr_value, torch._C.ScriptObject)
          or isinstance(attr_value, FakeScriptObject)
          or is_opaque_type(type(attr_value))
      )
      if is_torchbind_like:
          self.torchbind_constants[target] = attr_value  # type: ignore[arg-type]
          self.constant_reprs[target] = ""
          return TorchBindObject(name=target, value=attr_value)  # type: ignore[arg-type]

      assert isinstance(attr_value, torch.Tensor)
      must_keep_as_tensor = (
          config.aot_inductor.use_runtime_constant_folding
          or config.always_keep_tensor_constants
          or unsupported_output_tensor(attr_value)
          or target in self.mutated_named_buffers
      )
      if must_keep_as_tensor:
          return self.add_tensor_constant(attr_value, target)

      with no_dispatch():
          if attr_value.shape == ():
              return Constant(
                  value=attr_value.item(),
                  dtype=attr_value.dtype,
                  device=attr_value.device,
              )
          if self.can_inline_constant(attr_value):
              log.debug("Inlining constant: %s ", str(target))
              # tensor lowering has constant inlining logic
              from .lowering import tensor

              return tensor(
                  attr_value.tolist(),
                  dtype=attr_value.dtype,
                  device=attr_value.device,
              )

      return self.add_tensor_constant(attr_value, target)
  def call_module(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
      """
      Reject ``call_module`` nodes.

      This hook never returns: it always raises ``AssertionError``. The
      message names the hook and the offending target so the failure can be
      diagnosed from the traceback alone.
      """
      raise AssertionError(
          f"call_module is not supported during graph lowering (target={target!r})"
      )
  def call_method(self, target: Any, args: Any, kwargs: Any) -> NoReturn:
      """
      Reject ``call_method`` nodes.

      This hook never returns: it always raises ``AssertionError``. The
      message names the hook and the offending target so the failure can be
      diagnosed from the traceback alone.
      """
      raise AssertionError(
          f"call_method is not supported during graph lowering (target={target!r})"
      )
  # pyrefly: ignore [bad-override]
  def output(
      self,
      target: str,  # type: ignore[override]
      args: tuple[object],  # type: ignore[override]
      kwargs: dict[str, object],
  ) -> None:
      """
      Lower the graph's ``output`` node and record the final graph outputs.

      Steps, in order:

      1. Evaluate the outputs through the parent class and normalise a
         singleton result to a 1-tuple.
      2. Realize every output and, for tensor-like outputs, either copy it
         (comm-buffer layouts) or match its insignificant strides against the
         FX node's ``meta["val"]`` strides. The result is stored in
         ``self.graph_outputs``.
      3. For every ``TensorBox`` graph input that no longer wraps its own
         ``InputBuffer``, realize the new value into the original input and
         substitute the original input for it in ``self.graph_outputs``.
      4. Call ``self.finalize()`` and log the channels-last conv count.
      """
      allowed_output_types = (
          TensorBox,
          ir.Constant,
          type(None),
          ir.ConstantBuffer,
          sympy.Expr,
          sympy.logic.boolalg.Boolean,
          int,
          ir.EffectfulKernel,
          ir.ShapeAsConstantBuffer,
          TorchBindObject,
      )

      outputs = super().output(target, args, kwargs)  # type: ignore[arg-type]
      if not isinstance(outputs, (tuple, list)):
          # nested subgraphs can have singleton outputs
          outputs = (outputs,)
      assert isinstance(outputs, (tuple, list)), type(outputs)
      assert all(isinstance(x, allowed_output_types) for x in outputs), outputs

      fx_outputs = V.graph.current_node.args[0]  # type: ignore[arg-type]
      if not isinstance(fx_outputs, (tuple, list)):
          # nested subgraphs can have singleton outputs
          fx_outputs = (fx_outputs,)

      realized = [ir.ExternKernel.realize_input(x) for x in outputs]
      stride_matched = []
      assert len(fx_outputs) == len(realized)
      for ir_out, fx_out in zip(realized, fx_outputs):
          if not isinstance(ir_out, (ir.TensorBox, ir.BaseView)):
              # Non-tensor outputs are passed through untouched.
              stride_matched.append(ir_out)
              continue

          if isinstance(ir_out.get_output_spec(), ir.CommBufferLayout):
              # Active references to persistent comm buffers are not allowed
              # outside of graphs
              stride_matched.append(ir.ExternKernel.copy_input(ir_out))
              continue

          # AOT Autograd tries to detect stride divergence of inductor from
          # output metadata. Here, we try to avoid spurious divergence by
          # matching insignificant strides.
          # NOTE(review): "insignificant" presumably means strides of size-0/1
          # dimensions -- confirm against ir.try_match_insignificant_strides.

          # should have already been realized
          assert torch._inductor.ir.is_storage_and_layout(ir_out)
          fx_strides = [
              s.node.expr if isinstance(s, torch.SymInt) else s
              # pyrefly: ignore [missing-attribute]
              for s in fx_out.meta["val"].stride()
          ]
          stride_matched.append(
              ir.try_match_insignificant_strides(ir_out, fx_strides)
          )

      self.graph_outputs = stride_matched

      for name, graph_input in self.graph_inputs.items():
          if isinstance(graph_input, TorchBindObject):
              continue
          assert isinstance(
              graph_input, (TensorBox, sympy.Expr, torch._inductor.ir.GeneratorState)
          ), f"Unsupported inductor graph input type: {type(graph_input)}"
          if not isinstance(graph_input, TensorBox):
              continue

          graph_input.realize()
          assert isinstance(graph_input, TensorBox)
          storage_box = graph_input.data
          assert isinstance(storage_box, ir.StorageBox)
          buffer = storage_box.data
          if isinstance(buffer, InputBuffer) and buffer.get_name() == name:
              # The input still wraps its own buffer: it was not mutated.
              continue

          # one of our inputs was mutated, need to turn that into a copy
          ir.MutationLayoutSHOULDREMOVE.realize_into(
              buffer, self.graph_inputs_original[name]
          )
          # replace output with mutated input
          try:
              position = self.graph_outputs.index(storage_box)
              self.graph_outputs[position] = self.graph_inputs_original[name]
          except ValueError:
              # The mutated input is not itself a graph output.
              pass

      self.finalize()
      log.debug(
          "Force channels last inputs for %d conv for the current graph with id %d",
          self.num_channels_last_conv,
          self.graph_id if self.graph_id is not None else -1,
      )
  def finalize(self) -> None:
      """Ask every buffer in ``self.buffers`` to decide its layout, in order."""
      for buffer in self.buffers:
          buffer.decide_layout()
  @contextmanager
  def set_current_node(self, node: torch.fx.Node):  # type: ignore[no-untyped-def]
      """
      Context manager that makes ``node`` the value of ``self.current_node``.

      The previous value is restored on exit, including when the managed
      body raises.
      """
      previous = self.current_node
      self.current_node = node
      try:
          yield
      finally:
          self.current_node = previous
  @contextmanager
  def set_current_wrapper_code(self) -> Iterator[None]:
      """
      Context manager that snapshots ``self.wrapper_code`` on entry.

      Whatever the managed body assigns to ``self.wrapper_code`` is discarded
      on exit: the snapshot is written back, including when the body raises.
      """
      saved_wrapper_code = self.wrapper_code
      try:
          yield
      finally:
          self.wrapper_code = saved_wrapper_code
  def propagate_mutation(
      self,
      fx_node: torch.fx.Node,
      old_args: tuple[Any],
      old_kwargs: dict[str, Any],
      new_args: tuple[Any],
      new_kwargs: dict[str, Any],
  ) -> None:
      """Write mutations made to new_args/new_kwargs back to old_args/old_kwargs.

      The caller may have cloned some of old_args/old_kwargs into
      new_args/new_kwargs before invoking fx_node(*new_args, **new_kwargs).
      For every argument that fx_node mutates and whose new object is not the
      old object, an ``aten.copy_`` from new to old is lowered so the original
      tensor observes the mutation.

      Two cases are handled:

      * ``triton_kernel_wrapper_mutation``: the mutated kernel kwargs are
        obtained from ``get_mutated_tensors``.
      * any other target must be an ``OpOverload``; an argument counts as
        mutated when its schema entry has write alias info.
      """
      assert len(old_args) == len(new_args)
      assert len(old_kwargs) == len(new_kwargs)

      def write_back(dst: Any, src: Any) -> None:
          # Same object on both sides means no copy was made: nothing to do.
          if dst is src:
              return
          self.call_function(torch.ops.aten.copy_.default, (dst, src), {})

      if fx_node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation:
          kernel_kwargs = fx_node.kwargs["kwargs"]
          assert isinstance(kernel_kwargs, dict)
          mutated_names = (
              torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors(
                  old_kwargs["kernel_idx"],
                  old_kwargs["constant_args_idx"],
                  {
                      k: v.meta["val"] if isinstance(v, torch.fx.Node) else v
                      for k, v in kernel_kwargs.items()
                  },
                  old_kwargs["tma_descriptor_metadata"],
              )
          )
          for name in mutated_names:
              write_back(old_kwargs["kwargs"][name], new_kwargs["kwargs"][name])
          return

      assert isinstance(fx_node.target, torch._ops.OpOverload)

      def maybe_propagate(
          schema_arg: torch._C.Argument, old_arg: ir.IRNode, new_arg: ir.IRNode
      ) -> None:
          if old_arg is new_arg:
              return
          if schema_arg.alias_info is None or not schema_arg.alias_info.is_write:
              # The op does not write to this argument.
              return
          # The lowering for copy_ is smart enough to "replace" old_arg with
          # new_arg in all future uses so a copy_ kernel never gets emitted.
          # old_arg, new_arg may be immutable_list
          if isinstance(old_arg, ir.IRNode):
              old_items = (old_arg,)
              new_items = (new_arg,)
          else:
              old_items = old_arg
              new_items = new_arg
          for old_item, new_item in zip(old_items, new_items):  # type: ignore[call-overload]
              write_back(old_item, new_item)

      schema = fx_node.target._schema

      # Positional arguments line up with the schema by index.
      for position, (old_arg, new_arg) in enumerate(zip(old_args, new_args)):
          maybe_propagate(schema.arguments[position], old_arg, new_arg)

      # Keyword arguments are matched to the schema by name.
      schema_by_name = {arg.name: arg for arg in schema.arguments}
      for key, old_arg in old_kwargs.items():
          new_arg = new_kwargs[key]
          maybe_propagate(schema_by_name[key], old_arg, new_arg)
  def run_node(self, n: torch.fx.Node) -> object:
      """Lower and execute a single FX node into Inductor IR.

      The node is dispatched to one of several lowering paths (fallback
      handler, user-defined Triton kernel with layout constraints, magic
      method shortcut, or the parent class's ``run_node``). Still inside the
      origin/current-node context, the result is then post-processed:

      * stride requirements are applied when the node feeds the graph output
        or an as_strided-style op;
      * realization hints are issued for results with several users, too many
        reads, or a large pointwise inner function.

      After the context exits, the origin node is assigned, users are
      registered, and the unbacked symbols defined by the buffers and
      operations created while lowering ``n`` are collected. Those must cover
      the node's ``unbacked_bindings``; deferred runtime asserts are then
      emitted via ``create_deferred_runtime_asserts``.

      Args:
          n: the FX node to lower.

      Returns:
          The lowered result for ``n``.

      Raises:
          RuntimeError: if ``config.triton_kernel_default_layout_constraint``
              holds a value other than ``"flexible_layout"`` or
              ``"needs_fixed_stride_order"`` on the Triton-kernel path.
          AssertionError: if the newly defined unbacked symbols do not cover
              the node's (renamed) unbacked bindings.
      """

      def debug(msg: str) -> None:
          log.debug("lowering %s %s", LazyString(n.format_node), msg)  # type: ignore[arg-type]

      # Use channels-last stride order for certain
      # dense 4D intermediates when layout optimization determines a
      # downstream consumer (typically conv) prefers channels-last.
      # NOTE: is_user_visible and is_input_for_as_strided are closure
      # variables assigned further down in run_node, before this helper is
      # first called.
      def maybe_apply_channels_last_stride_order(
          result: ir.IRNode, n: torch.fx.Node
      ) -> ir.IRNode:
          dense = torch._prims_common.is_non_overlapping_and_dense_or_false(
              n.meta["val"]
          )
          strides = n.meta["val"].stride()
          unbacked_symbols_in_strides = len(free_unbacked_symbols(strides)) > 0
          if (
              not unbacked_symbols_in_strides
              and dense
              and len(result.get_size()) == 4
              and n in self.nodes_prefer_channels_last
              and not is_user_visible
              and not is_input_for_as_strided
          ):
              result = ir.ExternKernel.require_stride_order(
                  result,
                  ir.get_stride_order(
                      make_channels_last_strides_for(n.meta["val"].shape)
                  ),
              )
          return result

      from torch._inductor.compiler_bisector import CompilerBisector

      # Remember how many buffers/operations exist now so that the ones
      # created while lowering this node can be sliced out afterwards.
      buffer_watermark = len(self.buffers)
      operation_watermark = len(self.operations)

      # origins: OrderedSet[Union[Node, ir.IRNode]] = OrderedSet([n])
      origins: OrderedSet[Any] = OrderedSet([n])
      is_call_function = n.op == "call_function"
      if is_call_function:
          # args/kwargs are only bound for call_function nodes; every branch
          # below that uses them also checks n.op == "call_function".
          args, kwargs = self.fetch_args_kwargs_from_env(n)
          origins |= gather_origins(args, kwargs)
      with (
          ir.IRNode.current_origins(origins),
          self.set_current_node(n),
          V.set_current_node(n),
      ):
          if (
              n.op == "call_function"
              # this path only for built-in operators
              and n.target
              and isinstance(n.target, torch._ops.OpOverload)
              and torch._library.utils.is_builtin(n.target)
              and (
                  fallback_node_due_to_unsupported_type(n)
                  or CompilerBisector.disable_subsystem(
                      "inductor", "lowerings", lambda: repr(n)
                  )
              )
          ):
              debug("fallback_handler")
              result = fallback_handler(n.target, add_to_fallback_set=False)(
                  *args,  # type: ignore[possibly-undefined]
                  **kwargs,  # type: ignore[possibly-undefined]
              )
          elif (
              n.op == "call_function"
              and isinstance(
                  n.target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
              )
              and should_fallback_by_default(n)
          ):
              # this path supports fallback due to inductor lite mode. It supports
              # both OpOverload and HOPs (e.g., triton_kernel_wrapper_functional).
              debug("fallback_handler")
              result = fallback_handler(n.target, add_to_fallback_set=False)(
                  *args,  # type: ignore[possibly-undefined]
                  **kwargs,  # type: ignore[possibly-undefined]
              )
          elif (
              n.op == "call_function"
              and n.target is torch.ops.higher_order.triton_kernel_wrapper_mutation
              and config.triton_kernel_default_layout_constraint != "flexible_layout"
          ):
              debug("user_defined_triton_kernel_layout_constraints")
              if (
                  config.triton_kernel_default_layout_constraint
                  == "needs_fixed_stride_order"
              ):
                  # Keep the unconstrained inputs so that mutations applied to
                  # constrained copies can be propagated back afterwards.
                  old_args = args  # type: ignore[possibly-undefined]
                  old_kwargs = kwargs  # type: ignore[possibly-undefined]
                  if eager_input_vals := n.meta.get("eager_input_vals"):
                      inp_args = eager_input_vals[0]
                      inp_kwargs = eager_input_vals[1]
                      args, kwargs = constrain_to_fake_tensors(
                          # pyrefly: ignore [unbound-name]
                          args,
                          # pyrefly: ignore [unbound-name]
                          kwargs,
                          inp_args,
                          inp_kwargs,
                      )
                  else:
                      args, kwargs = constrain_to_fx_strides(n, *args, **kwargs)  # type: ignore[index]
                  result = self.call_function(n.target, args, kwargs)  # type: ignore[arg-type]
                  self.propagate_mutation(n, old_args, old_kwargs, args, kwargs)  # type: ignore[possibly-undefined]
              else:
                  raise RuntimeError(
                      f"Unknown triton_kernel_default_layout_constraint: {config.triton_kernel_default_layout_constraint}"
                  )
          elif is_magic_method(n.target):
              # TODO: this is sus, it probably should be handled in the
              # lowerings themselves similarly to sym_size/sym-stride
              # https://github.com/pytorch/pytorch/issues/127789
              debug("is_magic_method")
              if isinstance(
                  n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool)
              ):
                  # Use the symbolic expression recorded on the node directly.
                  result = n.meta["val"].node.expr
              else:
                  result = super().run_node(n)
          else:
              debug("")
              result = super().run_node(n)

          # require the same stride order for dense outputs,
          # 1. user-land view() will not throw because inductor
          # output different strides than eager
          # long term the solution is to make view() always succeed
          # with infallible strides.
          # 2: as_strided ops, we need make sure its input has same size/stride with
          # eager model to align with eager behavior.
          as_strided_ops = [
              torch.ops.aten.as_strided.default,
              torch.ops.aten.as_strided_.default,
              torch.ops.aten.as_strided_scatter.default,
              torch.ops.aten.resize.default,
              torch.ops.aten.resize_as.default,
          ]
          is_output = any(user.op == "output" for user in n.users)
          is_user_visible = n in self.user_visible_output_strides
          is_input_for_as_strided = any(
              user.target in as_strided_ops for user in n.users
          )

          # Nodes tagged with "inductor_realize_to_strides" are realized and,
          # when the strides differ and are not symbolic, forced to the
          # stride order recorded in the FX metadata.
          if n.meta.get("inductor_realize_to_strides", False) and isinstance(
              result, TensorBox
          ):
              result.realize()
              strides = n.meta["val"].stride()
              sym_strides = torch._inductor.utils.any_is_symbolic(*strides)
              if result.maybe_get_stride() != strides and not sym_strides:
                  stride_order = ir.get_stride_order(strides)
                  result = ir.ExternKernel.require_stride_order(result, stride_order)
          if (
              is_output
              and isinstance(result, TensorBox)
              and isinstance(result.data, ir.BaseView)
          ):
              # Realize so that outputs are correctly aliased
              result.realize()

          if (is_output or is_input_for_as_strided) and isinstance(
              n.meta["val"], torch.Tensor
          ):
              # User-visible outputs use the strides recorded for them; all
              # other nodes use the strides from the FX metadata.
              if is_user_visible:
                  strides = self.user_visible_output_strides.get(n)
              else:
                  strides = n.meta["val"].stride()

              if strides is not None and len(strides) > 0:
                  allow_padding = (
                      config.pad_outputs or not is_user_visible
                  ) and not is_input_for_as_strided
                  dense = torch._prims_common.is_non_overlapping_and_dense_or_false(
                      n.meta["val"]
                  )
                  unbacked_symbols_in_strides = (
                      len(free_unbacked_symbols(strides)) > 0
                  )
                  # Same predicate as maybe_apply_channels_last_stride_order,
                  # but here it only swaps in channels-last strides.
                  if (
                      not unbacked_symbols_in_strides
                      and dense
                      and len(result.get_size()) == 4
                      and n in self.nodes_prefer_channels_last
                      and not is_user_visible
                      and not is_input_for_as_strided
                  ):
                      strides = ir.FlexibleLayout.stride_ordered_for_memory_format(
                          result.get_size(), torch.channels_last
                      )
                  if not unbacked_symbols_in_strides and len(strides):
                      # To avoid converting possible view ops to a copy kernel, we use the previous
                      # require_exact_strides to handle views. But ultimately it's better to require
                      # the right strides at the tensor definition.
                      if n.meta["val"]._is_view() or isinstance(
                          result.data,
                          ir.BaseView,
                      ):
                          result = ir.ExternKernel.require_stride_order(
                              result,
                              ir.get_stride_order(strides),
                              allow_padding=allow_padding,
                          )
                      else:
                          # Fix for 0-d tensors: if result size is empty,
                          # strides should also be empty
                          if len(result.get_size()) == 0 and len(strides) > 0:
                              strides = []
                          result = ir.ExternKernel.require_exact_strides(
                              result, strides, allow_padding=allow_padding
                          )

          # Realize if (1) any user need inputs realized, or (2) there is
          # already too many reads and rematerializing can be bad.
          num_users = len(OrderedSet(n.users))
          if num_users > 1 and isinstance(result, TensorBox):
              for user in n.users:
                  if user.target in needs_realized_inputs:
                      result.realize_hint()

                      # This inclusion is somewhat controversial (from
                      # discussion between Horace, Natalia, and Elias).
                      # Currently, it's not very clear why this is helpful.
                      # The general idea here is that even though a node may
                      # have FlexibleLayout, we still often *treat* it as if
                      # it was contiguous. This appears to sometimes result in
                      # suboptimal behavior.
                      #
                      # When we do a better job selecting layout, we should
                      # revisit this.
                      # NOTE: both lists below are rebuilt for every user that
                      # needs realized inputs.
                      need_fixed_layout = [
                          torch.ops.aten.convolution_backward.default,
                          torch.ops.aten.mm.default,
                          torch.ops.aten._int_mm.default,
                      ]
                      need_fixed_channels_last_layout = []
                      if not self.layout_opt:
                          need_fixed_layout.append(torch.ops.aten.convolution.default)
                      if torch._C._has_mkldnn:
                          need_fixed_layout += [
                              torch.ops.mkldnn._linear_pointwise.default,
                              torch.ops.mkldnn._linear_pointwise.binary,
                              torch.ops.aten.mkldnn_rnn_layer.default,
                              torch.ops.onednn.qlinear_pointwise.default,
                              torch.ops.onednn.qlinear_pointwise.tensor,
                              torch.ops.onednn.qlinear_pointwise.binary,
                              torch.ops.onednn.qlinear_pointwise.binary_tensor,
                          ]
                          need_fixed_channels_last_layout += [
                              torch.ops.mkldnn._convolution_pointwise.default,
                              torch.ops.mkldnn._convolution_pointwise.binary,
                              torch.ops.mkldnn._convolution_pointwise_.binary,
                              torch.ops.mkldnn._convolution_transpose_pointwise.default,
                              torch.ops.onednn.qconv_pointwise.default,
                              torch.ops.onednn.qconv2d_pointwise.binary,
                          ]
                          if torch._C.has_mkl:
                              need_fixed_layout += [torch.ops.mkl._mkl_linear.default]
                      if user.target in need_fixed_layout:
                          result = ir.ExternKernel.require_stride_order(
                              result,
                              ir.get_stride_order(n.meta["val"].stride()),
                              allow_padding=True,
                          )
                      # Only the first argument of these ops is forced to
                      # channels-last stride order.
                      if (
                          user.target in need_fixed_channels_last_layout
                          and n is user.args[0]
                      ):
                          result = ir.ExternKernel.require_stride_order(
                              result,
                              ir.get_stride_order(
                                  make_channels_last_strides_for(n.meta["val"].shape)
                              ),
                          )
                  if user.op == "output":
                      # pyrefly: ignore [missing-attribute]
                      if isinstance(result.data.data, (Pointwise, Reduction)):
                          result.realize()

              # Walk through view/mutable-box wrappers until a StorageBox (or
              # some other kind of node) is reached.
              _data = result.data  # type: ignore[attr-defined]
              while not isinstance(_data, StorageBox) and isinstance(
                  _data, (ir.BaseView, ir.MutableBox)
              ):
                  _data = _data.data

              if isinstance(_data, StorageBox) and _data.should_realize_on_reuse(
                  len(n.users)
              ):
                  result = maybe_apply_channels_last_stride_order(result, n)

              # TODO(jansel): introduce a store vs inline choice
              result.mark_reuse(len(n.users))

          # Realize if the IRNode already has accumulated lots of reads
          if isinstance(result, TensorBox) and result.has_exceeded_max_reads():
              # Prevent excessive accumulation in a computed buffer, when
              # there are multiple branches each with small number of memory
              # reads, but they converge to a user.
              result = maybe_apply_channels_last_stride_order(result, n)
              result.realize_hint()

          # Realize if a Pointwise has too much stuff to be inlined.
          # As this may cause RecursionError during Inductor's evaluation.
          if isinstance(result, TensorBox) and isinstance(result.data, StorageBox):
              curr = result.data.data
              if isinstance(curr, Pointwise):
                  # Use inner fn as a rough proxy. Good enough.
                  if curr.has_large_inner_fn(threshold=100):
                      result.realize()

      assign_origin_node(result, n)
      self.register_users_of(result)

      # Collect the unbacked symbols defined by everything created since the
      # watermarks were taken at the top of this method.
      new_unbacked_defs = OrderedSet[sympy.Symbol]()
      for buf in self.buffers[buffer_watermark:]:
          new_unbacked_defs |= buf.get_unbacked_symbol_defs()
      for op in self.operations[operation_watermark:]:
          new_unbacked_defs |= op.get_unbacked_symbol_defs()

      shape_env = V.graph.sizevars.shape_env

      # An input can be unbacked symint i.e.: when mark_unbacked is used.
      # in that case add it to new_unbacked_defs.
      if (
          n.op == "placeholder"
          and isinstance(result, sympy.Symbol)
          and shape_env.is_unbacked_symint(result)
      ):
          new_unbacked_defs.add(result)

      # Only called to build the message of the assertion further down.
      def format_new_defs() -> str:
          r = [
              f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
              for buf in self.buffers[buffer_watermark:]
          ]
          r.extend(
              f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
              for op in self.operations[operation_watermark:]
          )
          return "***\n".join(r)

      # We do not skip unbacked symints that are input for backward see the note below.
      if V.graph.is_backward and n.op == "placeholder":
          return result

      # Note [Backwards runtime asserts]
      # Backwards poses an interesting problem for deferred runtime
      # asserts. In the easy case, we may solely close over data
      # dependent sized tensors, and there are no binding sites for
      # unbacked SymInts. In this case, we can just drop all the
      # runtime asserts on the floor: no non-placeholder bindings, no
      # problem.
      #
      # However, it is *possible* for a fresh runtime assert to show up
      # between forwards and backwards. Right now, the freezing process
      # that happens when we lower forwards means that we will freeze
      # runtime asserts, and then the moment the backwards lowering
      # process attempts to add a new deferred runtime assert, we will
      # fail. Let's say you remove that assert. Now when we get here,
      # we need to make sure we actually emit these asserts (because we
      # can't emit them in forwards, we already compiled it). So we
      # have to do something here. But we don't want to reemit ALL
      # deferred runtime asserts, we only want to emit the NEW ones.
      # Therefore needing some sort of stratification in the ShapeEnv.
      # This is all doable, it just hasn't been done yet.

      unbacked_bindings = resolve_unbacked_bindings(
          V.graph.sizevars.shape_env, n.meta.get("unbacked_bindings", {})
      )
      assert unbacked_bindings is not None
      # When we do lowering, it is possible we reallocate unbacked SymInts.
      # So we need to line up the unbacked SymInts when performing the test
      # here
      #
      # In principle, we could permit lowering to introduce MORE unbacked
      # SymInts: as long as all the old unbacked ones are accounted for,
      # it's fine for inductor to introduce extra calls to item()/unbacked()
      # whatever. This actually happens in practice when an unbacked SymInt
      # gets memoized away; naively, when Inductor reprocesses a kernel, it
      # doesn't know that the memo still applies, and ends up allocating a
      # new symbol. However, this is generally a bad thing: we may still
      # end up needing to test equalities on the symbols, and a fresh
      # symbol is likely to hit lots of GuardOnDataDependent errors that
      # we already know facts for.
      renamed_unbacked_bindings = OrderedSet(
          V.fake_mode.shape_env.unbacked_renamings.get(s, s)
          for s in unbacked_bindings
      )
      assert new_unbacked_defs >= renamed_unbacked_bindings, (
          f"failed {new_unbacked_defs} >= {renamed_unbacked_bindings} (inductor >= fx)\n"
          f"fx node is: {n.format_node()}\n"
          f"new operations are:\n\n{format_new_defs()}"
      )
      self.create_deferred_runtime_asserts(n, new_unbacked_defs)
      return result
  def create_deferred_runtime_asserts(
      self, n: torch.fx.Node, new_unbacked_defs: OrderedSet[sympy.Symbol]
  ) -> None:
      """
      Register ``ir.AssertScalar`` operations for runtime asserts that become
      emittable once ``new_unbacked_defs`` have been defined.

      Args:
          n: the FX node that was just lowered.
          new_unbacked_defs: unbacked symbols defined while lowering ``n``.
      """
      # [NOTE] Codegen runtime asserts in Inductor
      #
      # We need to generate runtime asserts directly in Inductor instead
      # of just reusing the asserts from input graphs because we reuse the
      # same ShapeEnv as before. In particular, on subsequent graph passes,
      # we would immediately turn all of these assertions into noops,
      # because when we evaluated their expressions, we would see that
      # because we had a deferred runtime assert in the ShapeEnv, we
      # know "oh, of course this expression is True" already.
      # One example is below:
      #
      # class Model(torch.nn.Module):
      #     def forward(self, a, b, c):
      #         nz = torch.nonzero(a)
      #         ones = a.new_ones([nz.size(0), b.size(0)])
      #         torch._check(ones.size(0) >= 1)
      #         equals = torch.add(ones, c)
      #         return equals
      # torch._dynamo.mark_dynamic(c, 0)
      # When we reuse the ShapeEnv in Inductor lowering, the check that checks
      # a and nonzero have the same shape would be evaluated to True after we resolve
      # unbacked bindings using the ShapeEnv.
      # See test_unbacked_equals_input_size_runtime_assertion in test_aot_inductor.
      #
      #
      # In addition to the Inductor generated runtime asserts, we also
      # need the runtime asserts from the input graph, because some derived
      # runtime asserts on backed symints are not generated in Inductor. One example is
      # this: `y = x.reshape(100, -1).clone()`. x.shape[0] needs to be a multiple of 100.
      # See test_aoti_runtime_asserts_backed_symint in test_aot_inductor.

      def make_assert(expr: SympyBoolean, msg: str) -> None:
          # One AssertScalar node is registered both as a buffer and as an
          # operation.
          scalar_assert = ir.AssertScalar(expr, msg)
          self.register_buffer(scalar_assert, set_name=True)
          self.register_operation(scalar_assert)

      aoti_assert_scalar_node = (
          full_aoti_runtime_assert()
          and n.target is torch.ops.aten._assert_scalar.default
          and self.aot_mode
      )
      if aoti_assert_scalar_node:
          node_args, _ = self.fetch_args_kwargs_from_env(n)
          condition = node_args[0]
          if condition != True:  # noqa: E712
              make_assert(condition, f"{condition} to be True")
          return

      def is_convertible(s: Expr) -> bool:
          # A bound can be asserted only if it is neither infinity nor
          # something int() rejects with TypeError.
          if s in (int_oo, -int_oo):
              return False
          try:
              int(s)
          except TypeError:
              return False
          return True

      # bound_unbacked_symbols tracks the symbols that are created so far,
      # we use it to make sure that runtime assertions are added after all
      # symbols used in them are defined.
      self.bound_unbacked_symbols |= new_unbacked_defs

      shape_env = V.graph.sizevars.shape_env

      # Emit code for runtime asserts that can be inserted at this point.
      for symbol in new_unbacked_defs:
          pending_asserts = self.ras_by_symbol.pop(symbol, [])
          # NB: size-like not needed, we won't retrace
          value_range = shape_env.var_to_range[symbol]
          if not shape_env._default_unspecified_value_range().issubset(value_range):
              if is_convertible(value_range.lower):
                  make_assert(
                      symbol >= value_range.lower, f"{symbol} >= {value_range.lower}"
                  )
              if is_convertible(value_range.upper):
                  make_assert(
                      symbol <= value_range.upper, f"{symbol} <= {value_range.upper}"
                  )

          for ra in pending_asserts:
              unbound = free_unbacked_symbols(ra.expr) - self.bound_unbacked_symbols
              if not unbound:
                  make_assert(ra.expr, f"{ra.expr}")
                  continue
              # Some symbols are still undefined: park the assert under the
              # one whose string form sorts first.
              next_symbol = min(unbound, key=str)
              self.ras_by_symbol.setdefault(next_symbol, []).append(ra)
  def validate_can_generate_cpp_wrapper(self) -> None:
      """
      Check that a C++ wrapper can be generated in this environment.

      Raises:
          CppWrapperCodegenError: if ``config.disable_cpp_codegen`` is set, or
              if ``sys.platform`` is not one of linux, darwin or win32.
      """
      supported_platforms = ("linux", "darwin", "win32")

      if config.disable_cpp_codegen:
          raise CppWrapperCodegenError("C++ codegen is disabled")

      current_platform = sys.platform
      if current_platform not in supported_platforms:
          raise CppWrapperCodegenError(f"Unsupported platform {current_platform}")
  def init_wrapper_code(
      self,
      is_subgraph: bool = False,
      subgraph_name: Optional[str] = None,
      parent_wrapper_code: Optional[PythonWrapperCodegen] = None,
      partition_signatures: Optional[GraphPartitionSignature] = None,
  ) -> None:
      """
      Choose the graph's device type and create its wrapper code generator.

      ``self.device_type`` becomes the single device type in
      ``self.device_types`` other than "cpu" and "meta", or "cpu" when there is
      none. ``self.device_ops`` and ``self.wrapper_code`` are then set for
      that device. The four parameters are forwarded unchanged to the wrapper
      codegen class's ``create``.

      When ``self.const_module`` is set, the new wrapper code reuses the const
      module wrapper's ``_names_iter``.

      Raises:
          AssertionError: if more than one non-cpu/meta device type is
              present, or if no wrapper codegen is found for the device.
      """
      remaining_devices = self.device_types.copy()
      for ignored_device in ("cpu", "meta"):
          remaining_devices.discard(ignored_device)

      # TODO(Eikan): Only support mixing cpu and other device now.
      assert len(remaining_devices) <= 1, (
          f"Does not support mixing {'+'.join(remaining_devices)}"
      )

      if len(remaining_devices) == 0:
          self.device_type = "cpu"
      else:
          self.device_type = remaining_devices.pop()

      if self.cpp_wrapper:
          self.validate_can_generate_cpp_wrapper()

      self.device_ops = get_device_op_overrides(self.device_type)

      codegen_cls = get_wrapper_codegen_for_device(
          self.device_type, self.cpp_wrapper, self.fx_wrapper
      )
      assert codegen_cls is not None, f"Device {self.device_type} not supported"

      self.wrapper_code = codegen_cls.create(
          is_subgraph,
          subgraph_name,
          parent_wrapper_code,
          partition_signatures,
      )

      if self.const_module:
          self.wrapper_code._names_iter = self.const_module.wrapper_code._names_iter
  1984. def extract_autotune_inputs(
  1985. self, example_inputs: list[Union[int, float, torch.Tensor]]
  1986. ) -> None:
  1987. import copy
  1988. cloned_gm = copy.deepcopy(self.orig_gm)
  1989. example_inputs = copy.deepcopy(example_inputs)
  1990. triton_nodes = []
  1991. for node in cloned_gm.graph.nodes:
  1992. if (
  1993. node.op == "call_function"
  1994. and node.target is torch.ops.higher_order.triton_kernel_wrapper_mutation
  1995. ):
  1996. triton_nodes.append(node)
  1997. # Store grid related nodes
  1998. grid_inputs: list[torch.fx.Node] = []
  1999. visited_grids: dict[torch.fx.Node, int] = {}
  2000. # Store kwargs related nodes
  2001. triton_inputs: dict[str, Any] = {}
  2002. kwargs_inputs: list[torch.fx.Node] = []
  2003. visited_kwargs: dict[Any, int] = {}
  2004. for node in triton_nodes:
  2005. # first check whether we have fx node in grid settings.
  2006. for grid in node.kwargs["grid"]:
  2007. for val in grid:
  2008. if val in visited_grids:
  2009. continue
  2010. if isinstance(val, torch.fx.Node):
  2011. visited_grids[val] = len(grid_inputs)
  2012. grid_inputs.append(val)
  2013. kwargs = node.kwargs["kwargs"]
  2014. # identify which args might be mutated, those should be cloned.
  2015. mutated = torch._higher_order_ops.triton_kernel_wrap.get_mutated_tensors(
  2016. node.kwargs["kernel_idx"],
  2017. node.kwargs["constant_args_idx"],
  2018. {
  2019. k: v.meta["val"] if isinstance(v, torch.fx.Node) else v
  2020. for k, v in kwargs.items()
  2021. },
  2022. node.kwargs["tma_descriptor_metadata"],
  2023. )
  2024. new_kwargs: dict[str, int] = {}
  2025. with cloned_gm.graph.inserting_before(node):
  2026. for k, v in kwargs.items():
  2027. if k in mutated:
  2028. new_node = cloned_gm.graph.call_function(torch.clone, args=(v,))
  2029. new_kwargs[k] = len(kwargs_inputs)
  2030. kwargs_inputs.append(new_node)
  2031. continue
  2032. if v in visited_kwargs:
  2033. new_kwargs[k] = visited_kwargs[v]
  2034. continue
  2035. visited_kwargs[v] = len(kwargs_inputs)
  2036. kwargs_inputs.append(v)
  2037. new_kwargs[k] = visited_kwargs[v]
  2038. triton_inputs[node.name] = new_kwargs
  2039. new_outputs = kwargs_inputs + grid_inputs
  2040. for node in cloned_gm.graph.nodes:
  2041. if node.op == "output":
  2042. node.args = (tuple(new_outputs),)
  2043. break
  2044. cloned_gm.recompile()
  2045. runner = torch.fx.Interpreter(cloned_gm)
  2046. returned_outputs = runner.run(example_inputs)
  2047. # Extract and store the grid for autotuning
  2048. if len(grid_inputs) > 0:
  2049. grid_outputs = returned_outputs[len(kwargs_inputs) :]
  2050. self.autotuning_grids = {}
  2051. for node in triton_nodes:
  2052. dynamic_grid = False
  2053. new_grids: list[tuple[Any]] = []
  2054. for grid in node.kwargs["grid"]:
  2055. new_grid = []
  2056. for val in grid:
  2057. if not isinstance(val, torch.fx.Node):
  2058. new_grid.append(val)
  2059. continue
  2060. dynamic_grid = True
  2061. new_grid.append(grid_outputs[visited_grids[val]])
  2062. # pyrefly: ignore [bad-argument-type]
  2063. new_grids.append(tuple(new_grid))
  2064. if dynamic_grid:
  2065. self.autotuning_grids[node.name] = new_grids
  2066. # Store the kwargs input for autotuning
  2067. self.autotuning_inputs = returned_outputs[: len(kwargs_inputs)]
  2068. self.autotuning_mapping = triton_inputs
def codegen_with_cpp_wrapper(
    self,
) -> tuple[ValueWithLineMap, ValueWithLineMap]:
    """
    For GPU, Triton kernels are autotuned and stored as cubin files

    When neither "cuda" nor "xpu" is in ``self.device_types`` this is just
    ``self.codegen()``. Otherwise:

    * If ``config.triton.autotune_at_compile_time`` is set, codegen runs once.
      Before that, when ``config.triton.autotune_with_sample_inputs`` is set
      and the graph contains an ``ir.UserDefinedTritonKernel``, real inputs
      are built and handed to ``self.extract_autotune_inputs``.
    * Otherwise two passes are made: the graph is first compiled with
      ``self.cpp_wrapper`` set to False and the resulting module is called on
      real inputs; then ``self.cpp_wrapper`` is set back to True, per-graph
      codegen state is cleared, and ``self.codegen()`` is run again.

    Returns whatever ``self.codegen()`` returns.
    """
    if any(device in self.device_types for device in ["cuda", "xpu"]):

        def extract_real_inputs() -> list[Union[int, float, torch.Tensor]]:
            # Build a list of concrete inputs that the compiled module can be
            # called with.
            def materialize(
                x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
            ) -> Union[int, float, torch.Tensor]:
                # Convert a single (possibly symbolic or fake) input into a
                # concrete value; None is passed through unchanged.
                if x is None:
                    # pyrefly: ignore [bad-return]
                    return None
                elif isinstance(x, (torch.SymInt, torch.SymFloat)):
                    # Need concrete value to run dynamic shapes and tune the result
                    return x.node.hint
                elif isinstance(x, FakeTensor):
                    return defake(x)
                else:
                    assert isinstance(x, torch.Tensor), (
                        "Unknown type when creating real inputs" + str(type(x))
                    )
                    return x

            tracing_context = torch._guards.TracingContext.try_get()
            if tracing_context is not None and not isinstance(
                V.real_inputs, NullHandler
            ):
                if tracing_context.output_strides:
                    tracing_context.output_strides.clear()

                # Non-None flattened params come first, followed by V.real_inputs.
                params_flat = [
                    param
                    for param in tracing_context.params_flat  # type: ignore[union-attr]
                    if param is not None
                ]
                real_inputs = [
                    materialize(x)
                    for x in itertools.chain(params_flat, V.real_inputs)
                ]
            else:
                # In the backward pass, V.real_inputs is not set.
                # Generating random inputs based on self.example_inputs sometimes can be problematic,
                # e.g. illegal memory access. A comprehensive fix is to autotune in a separate process.
                real_inputs = [
                    materialize(x)  # type:ignore[arg-type]
                    for x in (
                        self.example_inputs  # type:ignore[union-attr]
                        if isinstance(V.real_inputs, NullHandler)
                        else V.real_inputs
                    )
                ]

            if self.mutated_inputs:
                from .compile_fx import clone_preserve_strides

                # Positions (in graph_inputs order) of mutated inputs whose
                # materialized value is a Tensor.
                mutated_input_idxs = [
                    idx
                    for idx, name in enumerate(self.graph_inputs)
                    if name in self.mutated_inputs
                    and isinstance(real_inputs[idx], torch.Tensor)
                ]

                for idx in mutated_input_idxs:
                    # clone mutated Tensor inputs to avoid mutating them in
                    # the first pass of the CPP wrapper-based compilation, as
                    # this will lead to a side effect on the example inputs:
                    # e.g. if torch.compile(f)(x) if called on input-mutating
                    # f, the inputs x will be mutated twice in the process:
                    # once here, and again when running the compiled model;
                    # this will also lead to a numerically incorrect output
                    mutated_inp = real_inputs[idx]
                    assert isinstance(mutated_inp, torch.Tensor)
                    real_inputs[idx] = clone_preserve_strides(mutated_inp)
                    del mutated_inp
            return real_inputs

        if config.triton.autotune_at_compile_time:
            # If autotune_at_compile_time is True, we can do the codegen in one-pass
            # We will construct the autotuning values if user defined kernel exists.
            if config.triton.autotune_with_sample_inputs:
                user_defined_kernels = False
                for op in self.operations:
                    if isinstance(op, ir.UserDefinedTritonKernel):
                        user_defined_kernels = True
                        break
                if user_defined_kernels:
                    real_inputs = extract_real_inputs()
                    self.extract_autotune_inputs(real_inputs)
            return self.codegen()
        else:
            # first pass
            # cpp_wrapper is turned off so compile_to_module() takes the plain
            # codegen() path; the resulting module is then run on real inputs.
            self.cpp_wrapper = False
            compiled = self.compile_to_module().call

            real_inputs = extract_real_inputs()
            with torch.utils._python_dispatch._disable_current_modes():
                compiled(real_inputs)
            del real_inputs

            # second pass
            # Restore cpp_wrapper and clear state left over from the first
            # pass before generating code again.
            self.cpp_wrapper = True
            self.removed_buffers.clear()
            self.removed_operations.clear()
            self.inplaced_to_remove.clear()
            V.graph.sizevars.precomputed_replacements.clear()
            V.graph.sizevars.inv_precomputed_replacements.clear()
            metrics.reset()
            with config.patch({"triton.autotune_at_compile_time": False}):
                return self.codegen()
    else:
        # cpu
        return self.codegen()
  2175. def _update_scheduler(self) -> None:
  2176. """
  2177. (Re)initializes the scheduler member. When initializing the scheduler, no CUBIN
  2178. files should be generated (to avoid biasing any benchmarks and pessimizing
  2179. fusion decisions).
  2180. """
  2181. from .scheduler import Scheduler
  2182. with config.patch("triton.store_cubin", False):
  2183. self.scheduler = Scheduler(self.operations)
def codegen(self) -> tuple[ValueWithLineMap, ValueWithLineMap]:
    """
    Generate the wrapper code for this graph.

    Initializes the wrapper codegen and the scheduler, runs the scheduler's
    codegen, and returns the result of ``self.wrapper_code.generate``. The
    whole call is timed under "GraphLowering.codegen".
    """
    with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True):
        self.init_wrapper_code()

        self._update_scheduler()
        V.debug.draw_orig_fx_graph(self.orig_gm, self.scheduler.nodes)

        # Register this graph with the wrapper while kernels are generated;
        # it is popped again once the wrapper has been generated.
        self.wrapper_code.push_codegened_graph(self)
        self.scheduler.codegen()

        log.debug(
            "Finished codegen for all nodes. The list of kernel names available: %s",
            V.graph.all_codegen_kernel_names,
        )

        result = self.wrapper_code.generate(self.is_inference)
        self.wrapper_code.pop_codegened_graph()
        return result
  2198. def codegen_subgraph(self, parent_graph: GraphLowering) -> None:
  2199. """
  2200. This is a more compact version of the `codegen()` above
  2201. where we codegen this graph as a subgraph of some parent
  2202. graph. The parent graph is passed as an argument: the
  2203. intention is to inline codegening of the subgraph in
  2204. the parent graph's wrapper code (including the generated
  2205. kernels). The wrapper code is not finalized (via `.generate()`
  2206. call), as this will be done in the parent graph's `codegen()`.
  2207. """
  2208. with dynamo_timed("GraphLowering.codegen_subgraph", log_pt2_compile_event=True):
  2209. self.wrapper_code = parent_graph.wrapper_code
  2210. self.device_ops = parent_graph.device_ops
  2211. self.cpp_wrapper = parent_graph.cpp_wrapper
  2212. self.device_types = parent_graph.device_types
  2213. self.device_idxs = parent_graph.device_idxs
  2214. self.device_type = parent_graph.device_type
  2215. self._update_scheduler()
  2216. self.scheduler.codegen()
  2217. def count_bytes(
  2218. self,
  2219. ) -> tuple[
  2220. int, list[tuple[BaseSchedulerNode, int]], list[tuple[BaseSchedulerNode, float]]
  2221. ]:
  2222. total_bytes = 0
  2223. node_counts = []
  2224. node_runtimes = []
  2225. for node in self.scheduler.nodes:
  2226. num_bytes = node.get_read_write_buffers_sizes()
  2227. total_bytes += num_bytes
  2228. node_counts.append((node, num_bytes // 4))
  2229. node_runtimes.append((node, node.get_estimated_runtime()))
  2230. return total_bytes, node_counts, node_runtimes
# Optional hook, unset (None) by default, meant to be patched by unit tests.
# When set, _compile_to_module_lines calls it with the final wrapper source.
save_output_code: Optional[Callable[[str], None]] = None
  2233. def compile_to_module(self) -> CompiledModule:
  2234. with dynamo_timed(
  2235. "GraphLowering.compile_to_module",
  2236. phase_name="code_gen",
  2237. log_pt2_compile_event=True,
  2238. dynamo_compile_column_us="inductor_code_gen_cumulative_compile_time_us",
  2239. ):
  2240. return self._compile_to_module()
  2241. def _compile_to_module(self) -> CompiledModule:
  2242. # If we're here, we don't have to worry about the kernel code, which is only
  2243. # returned separately in AOTInductor mode.
  2244. wrapper_code, _ = (
  2245. self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  2246. )
  2247. if isinstance(wrapper_code, ValueWithLineMap):
  2248. mod = self._compile_to_module_lines(wrapper_code)
  2249. elif isinstance(wrapper_code, FileBackedGraphModule):
  2250. mod = wrapper_code
  2251. else:
  2252. raise NotImplementedError(
  2253. f"Unrecognized wrapper code type: {type(wrapper_code)}"
  2254. )
  2255. # Logged twice as per https://github.com/pytorch/pytorch/pull/99038#discussion_r1167826029
  2256. # TODO. Revisit this once the logging API is more mature
  2257. assert mod.__file__ is not None
  2258. log_module_code(mod.__file__)
  2259. log.debug("Output code written to: %s", mod.__file__)
  2260. output_code_log.info("Output code written to: %s", mod.__file__)
  2261. if config.benchmark_kernel:
  2262. print(f"Compiled module path: {mod.__file__}", file=sys.stderr)
  2263. if isinstance(wrapper_code, FileBackedGraphModule):
  2264. V.debug.output_code(mod.__file__)
  2265. V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
  2266. return mod
def _compile_to_module_lines(
    self, wrapper_code: ValueWithLineMap
) -> CompiledModule:
    """
    Write the generated wrapper source to the code cache and load it.

    ``wrapper_code.value`` is modified in place when compile-time autotuning
    is enabled: the autotune kernel definitions and calls are prepended as a
    triple-quoted raw string block. The source is then written with
    ``PyCodeCache.write`` and loaded with ``PyCodeCache.load_by_key_path``.

    Side effects: sets ``self.cache_key``, ``self.cache_path`` and
    ``self.cache_linemap``; calls ``GraphLowering.save_output_code`` when it
    is set; emits an "inductor_output_code" structured trace entry.

    Returns the loaded module.
    """
    from .codecache import PyCodeCache

    if config.triton.autotune_at_compile_time:
        # sanitize docstrings in kernel defs (#155006)
        # NOTE(review): presumably so a triple quote inside a kernel def
        # cannot terminate the enclosing r""" block built below.
        kernel_autotune_defs = self.wrapper_code.kernel_autotune_defs.getvalue()
        kernel_autotune_defs = kernel_autotune_defs.replace('"""', '\\"\\"\\"')

        tuning_code = (
            'r"""\n'
            + "Compile-time auto-tuning block: \n"
            + kernel_autotune_defs
            + self.wrapper_code.kernel_autotune_calls.getvalue()
            + '"""\n'
        )
        wrapper_code.value = tuning_code + wrapper_code.value
    if GraphLowering.save_output_code is not None:
        GraphLowering.save_output_code(wrapper_code.value)
    output_code_log.debug("Output code: \n%s", wrapper_code.value)

    inductor_meta = autotune_cache.inductor_meta_from_config()
    AutotuneCacheBundler.begin_compile(inductor_meta, code=wrapper_code.value)

    try:
        # Pair each recorded line number with the stack trace of its node.
        linemap = [
            (line_no, node.stack_trace)  # type: ignore[attr-defined]
            for line_no, node in wrapper_code.line_map
        ]
        key, path = PyCodeCache.write(wrapper_code.value)
        output_code_log.debug("Output code written to: %s", path)

        V.debug.output_code(path)
        V.debug.copy(os.path.splitext(path)[0] + ".debug")
    except Exception:
        # Still record the generated code before propagating the error.
        trace_structured(
            "inductor_output_code",
            # Just omit the filename, I still want the code though!
            payload_fn=lambda: wrapper_code.value,
        )
        raise
    else:
        # On success the trace entry also carries the file location.
        trace_structured(
            "inductor_output_code",
            lambda: {
                "filename": path,
                "file_path": os.path.abspath(path),
            },
            payload_fn=lambda: wrapper_code.value,
        )
    with dynamo_timed("PyCodeCache.load_by_key_path", log_pt2_compile_event=True):
        # The module is loaded with the graph's constants, torchbind constants
        # and opaque value type classes passed as attrs.
        mod = PyCodeCache.load_by_key_path(
            key,
            path,
            linemap=linemap,  # type: ignore[arg-type]
            attrs={
                **self.constants,
                **self.torchbind_constants,
                **self.opaque_value_type_classes,
            },
        )
    self.cache_key = key
    self.cache_path = path
    self.cache_linemap = linemap  # type: ignore[assignment]

    if config.benchmark_harness and config.profile_bandwidth_output:
        # run the inputs code gen to get the bandwidth info
        mod.benchmark_compiled_module(times=1, repeat=1)

    return mod
  2331. def _get_output_names(self, graph_outputs: list[ir.IRNode]) -> list[str]:
  2332. names = []
  2333. shape_counter = itertools.count(0)
  2334. none_counter = itertools.count(0)
  2335. for node in graph_outputs:
  2336. if isinstance(node, ir.NoneAsConstantBuffer):
  2337. names.append(f"{self.name}_none{next(none_counter)}")
  2338. elif isinstance(node, ir.ShapeAsConstantBuffer):
  2339. names.append(f"{self.name}_shape{next(shape_counter)}")
  2340. else:
  2341. names.append(node.get_name())
  2342. return names
  2343. def get_output_names(self) -> list[str]:
  2344. return self._get_output_names(self.graph_outputs)
  2345. def is_unspec_arg(self, name: str) -> bool:
  2346. # dynamo wraps unspec variable as 0d CPU tensor,
  2347. # need to convert to scalar during codegen (triton only)
  2348. return (
  2349. name in self.graph_inputs
  2350. and self.graph_inputs[name].get_numel() == 1
  2351. and len(self.graph_inputs[name].get_size()) == 0
  2352. and get_device_type(self.graph_inputs[name]) == "cpu"
  2353. ) or name in self.zero_dim_cpu_tensor_list
  2354. class SubgraphLowering(GraphLowering):
  2355. """
  2356. Mostly a helper class for the subgraph lowering. The main goal is to call
  2357. init_wrapper_code with the subgraph related arguments.
  2358. """
  2359. def __init__(self, parent: GraphLowering, *args: Any, **kwargs: Any) -> None:
  2360. self.parent = parent
  2361. super().__init__(*args, **kwargs)
  2362. def init_wrapper_code(
  2363. self,
  2364. is_subgraph: bool = False,
  2365. subgraph_name: Optional[str] = None,
  2366. parent_wrapper_code: Optional[PythonWrapperCodegen] = None,
  2367. partition_signatures: Optional[GraphPartitionSignature] = None,
  2368. ) -> None:
  2369. super().init_wrapper_code(
  2370. is_subgraph=True,
  2371. subgraph_name=self.name,
  2372. parent_wrapper_code=self.parent.wrapper_code,
  2373. )