utils.py 146 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
7427842794280428142824283428442854286428742884289429042914292429342944295429642974298429943004301430243034304430543064307430843094310431143124313431443154316431743184319432043214322432343244325432643274328432943304331433243334334433543364337433843394340434143424343434443454346434743484349435043514352435343544355435643574358435943604361436243634364436543664367436843694370437143724373437443754376437743784379438043814382438343844385438643874388438943904391439243934394439543964397439843994400440144024403440444054406440744084409441044114412441344144415441644174418441944204421442244234424442544264427442844294430443144324433443444354436443744384439444044414442444344444445444644474448444944504451445244534454445544564457445844594460446144624463446444654466446744684469447044714472447344744475447644774478447944804481448244834484448544864487448844894490449144924493449444954496449744984499450045014502450345044505450645074508450945104511451245134514451545164517451845194520452145224523452445254526452745284529453045314532453345344535
  1. from __future__ import annotations
  2. import collections
  3. import contextlib
  4. import dataclasses
  5. import enum
  6. import functools
  7. import importlib
  8. import inspect
  9. import io
  10. import itertools
  11. import logging
  12. import math
  13. import operator
  14. import os
  15. import platform
  16. import re
  17. import shutil
  18. import statistics
  19. import sys
  20. import sysconfig
  21. import tempfile
  22. import textwrap
  23. import time
  24. import unittest
  25. from collections.abc import (
  26. Callable,
  27. Collection,
  28. Generator,
  29. Iterator,
  30. Mapping,
  31. MutableMapping,
  32. MutableSet,
  33. )
  34. from datetime import datetime
  35. from functools import lru_cache
  36. from io import StringIO
  37. from typing import (
  38. Any,
  39. cast,
  40. Concatenate,
  41. Generic,
  42. Literal,
  43. NamedTuple,
  44. Optional,
  45. Protocol,
  46. TYPE_CHECKING,
  47. TypeAlias,
  48. TypeGuard,
  49. TypeVar,
  50. Union,
  51. )
  52. from typing_extensions import dataclass_transform, ParamSpec, Self
  53. from unittest import mock
  54. import sympy
  55. import torch
  56. import torch.utils._pytree as pytree
  57. from torch._inductor.analysis.device_info import datasheet_tops
  58. from torch._inductor.runtime.hints import DeviceProperties
  59. from torch.fx.passes.regional_inductor import _needs_inductor_compile
  60. from torch.utils._dtype_abbrs import dtype_abbrs
  61. from torch.utils._ordered_set import OrderedSet
  62. from torch.utils._pytree import tree_flatten, tree_map_only
# Names of post-grad passes/configs to exclude from Optimus post-grad handling.
# NOTE(review): the consumer of this list is not visible in this chunk; the exact
# semantics of "exclude" should be confirmed against its call site.
OPTIMUS_EXCLUDE_POST_GRAD = [
    "activation_quantization_aten_pass",
    "inductor_autotune_lookup_table",
]
  67. from torch.fx.experimental.symbolic_shapes import (
  68. free_symbols,
  69. free_unbacked_symbols,
  70. IterateExprs,
  71. ShapeEnv,
  72. )
  73. if TYPE_CHECKING:
  74. from collections.abc import Iterable, Sequence, ValuesView
  75. from pathlib import Path
  76. from torch import SymBool, SymFloat, SymInt
  77. from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
  78. from torch.fx import GraphModule
  79. from torch.fx.node import Node
  80. from torch.nn.functional import ScalingType # type: ignore[attr-defined]
  81. from .codegen.common import WorkspaceArg
  82. from .codegen.wrapper import PythonWrapperCodegen
  83. from .dependencies import Dep
  84. from .graph import GraphLowering
  85. from .ir import Buffer, ExternKernel, IRNode, Layout, Operation, ReinterpretView
  86. from .output_code import CompiledFxGraph
  87. from .scheduler import BaseSchedulerNode, SchedulerBuffer
  88. GPU_TYPES = ["cuda", "mps", "xpu", "mtia"]
  89. T = TypeVar("T")
  90. # defines here before import torch._dynamo is for avoiding circular import
  91. # when get_gpu_type is imported from dynamo
  92. @functools.cache
  93. def get_gpu_type() -> str:
  94. avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()]
  95. assert len(avail_gpus) <= 1
  96. gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop()
  97. return gpu_type
  98. from torch._dynamo.device_interface import get_interface_for_device
  99. from torch._dynamo.utils import detect_fake_mode
  100. from torch.autograd import DeviceType
  101. from torch.autograd.profiler_util import EventList
  102. from torch.fx.passes.graph_transform_observer import GraphTransformObserver
  103. from torch.fx.passes.shape_prop import ShapeProp
  104. from torch.utils._sympy.functions import (
  105. CeilDiv,
  106. CleanDiv,
  107. FloorDiv,
  108. Identity,
  109. ModularIndexing,
  110. )
  111. from torch.utils._sympy.symbol import make_symbol, SymT
  112. from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
  113. from . import config
  114. from .runtime.runtime_utils import ceildiv as runtime_ceildiv
# True when running on Windows.
_IS_WINDOWS = sys.platform == "win32"
# Module-level logger.
log = logging.getLogger(__name__)
_T = TypeVar("_T")
# Mapping from sympy expressions to sympy expressions; presumably index
# variables to their ranges, judging by the name — TODO confirm with callers.
VarRanges = dict[sympy.Expr, sympy.Expr]
# An input that may be a tensor, a plain int, a SymInt, or None.
InputType = Optional[Union[torch.Tensor, int, torch.SymInt]]
# Binary format for XPU kernels: always "spv" on Windows; otherwise read from the
# TORCHINDUCTOR_XPU_KERNEL_FORMAT environment variable (default "zebin").
XPU_KERNEL_FORMAT = (
    "spv" if _IS_WINDOWS else os.getenv("TORCHINDUCTOR_XPU_KERNEL_FORMAT", "zebin")
)
# File extension of compiled GPU kernel binaries, keyed by backend name.
GPU_KERNEL_BIN_EXTS = {
    "cuda": ".cubin",
    "hip": ".hsaco",
    "xpu": f".{XPU_KERNEL_FORMAT}",
}
# Alignment / size constants. NOTE(review): units are presumably bytes (as the
# GPU_ALIGN_BYTES name suggests) — confirm at the use sites.
GPU_ALIGN_BYTES = 16
ALIGNMENT = 16
TMA_ALIGNMENT = 16
TMA_DESCRIPTOR_SIZE = 128
# PyTorch dtypes with valid CUtensorMapDataType mappings.
# Ref: triton/backends/nvidia/include/cuda.h (CUtensorMapDataType enum)
#      triton/_internal_testing.py (tma_dtypes test list)
# Stored as an OrderedSet, so membership tests are set-like and iteration follows
# the insertion order below.
_TMA_SUPPORTED_DTYPES: OrderedSet[torch.dtype] = OrderedSet(
    [
        torch.uint8,
        torch.int8,
        torch.uint16,
        torch.int16,
        torch.uint32,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.bfloat16,
        torch.float32,
        torch.float64,
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.float8_e4m3fnuz,
        torch.float8_e5m2fnuz,
    ]
)
  154. ALIGN_BYTES = 64
  155. assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2"
  156. def _align(nbytes: int) -> int:
  157. """Round up to the nearest multiple of ALIGN_BYTES"""
  158. return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES
  159. def _is_aligned(v: sympy.Expr) -> bool:
  160. """v can be statically proven to be a multiple of ALIGN_BYTES"""
  161. if isinstance(v, (sympy.Add, sympy.Max)):
  162. return all(map(_is_aligned, v.args))
  163. return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES
  164. class align(sympy.Function):
  165. """Symbolically round up to the nearest multiple of ALIGN_BYTES"""
  166. nargs = (1,)
  167. is_integer = True
  168. @classmethod
  169. def eval(cls, value: sympy.Expr) -> Optional[sympy.Expr]:
  170. if isinstance(value, (int, sympy.Integer)):
  171. return _align(int(value))
  172. if _is_aligned(value):
  173. return value
@dataclasses.dataclass(frozen=True)
class GraphPartitionMap:
    """
    Mapping from a graph partition's info (e.g., its inputs/outputs) to the
    corresponding info of the whole graph.

    ``frozen=True`` only prevents reassigning the fields; the list fields
    themselves remain mutable objects.
    """

    # a unique id of graph partition
    id: int
    # Map partition input/output indices to graph input/output indices. None
    # indicates that a partition input/output is not a graph input/output.
    input_index_mapping: list[Optional[int]]
    output_index_mapping: list[Optional[int]]
    # names of constants read/written by the graph partition
    constant_names: list[str]
def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float:
    """
    Benchmark ``fn`` on CUDA and return its mean runtime in milliseconds, minus
    the mean time of any profiled ``fused_abs_max_<digit>`` kernels.

    Procedure, as implemented below:
      * run ``fn`` once and synchronize;
      * estimate the per-call time from 5 calls bracketed by CUDA events. Each
        call is preceded by zeroing a 64M-element float16 buffer, presumably to
        flush the L2 cache, so the estimate includes that fill;
      * derive warmup/repeat counts (each at least 1) from the ``warmup`` and
        ``rep`` budgets, which are therefore in milliseconds;
      * with a CUDA-activity profiler active, time each of ``n_repeat`` calls
        with its own pair of CUDA events (again zeroing the buffer first, outside
        the event pair) and take the mean of those timings;
      * if the profiler recorded CUDA events whose names match
        ``fused_abs_max_<digit>``, subtract the mean ``device_time_total`` of
        those events divided by 1000 (profiler microseconds -> ms).

    Because timings come from device events rather than host clocks, host-side
    overhead is presumably excluded. Excluding kernels by name is fragile.

    Args:
        fn: zero-argument callable to benchmark.
        warmup: approximate warm-up time budget in ms.
        rep: approximate measurement time budget in ms.

    Returns:
        The adjusted mean runtime in milliseconds.

    NOTE(review): if the 5-call estimate is 0 ms, the warmup/repeat computation
    divides by zero.
    """
    fn()
    torch.cuda.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.float16, device="cuda")

    # Estimate the runtime of the function
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.cuda.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()

    # One start/end event pair per measured call.
    start_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    end_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        torch.cuda.synchronize()
        for i in range(n_repeat):
            cache.zero_()
            start_event[i].record()
            with torch.cuda.nvtx.range("RunCudaModule"):
                fn()
            end_event[i].record()
        torch.cuda.synchronize()
        times = torch.tensor(
            [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
        )

    res = torch.mean(times).item()
    log.debug("raw events")
    log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))
    # Kernels matching this name pattern are excluded from the reported time.
    filtered_events = EventList(
        [
            event
            for event in p.events()
            if (
                event.device_type == DeviceType.CUDA
                and re.match(r"fused_abs_max_\d", event.name) is not None
            )
        ]
    )
    if filtered_events:
        # device_time_total is divided by 1000 to convert to ms.
        res -= (
            statistics.mean(event.device_time_total for event in filtered_events)
            / 1000.0
        )

    log.debug("profiling results: %s ms", res)
    return res
  252. def do_bench_using_profiling(
  253. fn: Callable[[], Any],
  254. warmup: int = 25,
  255. rep: int = 100,
  256. is_vetted_benchmarking: bool = False,
  257. ) -> float:
  258. # We did't use decorator may_distort_benchmarking_result directly since that
  259. # requires us to import torch._inductor.runtime.benchmarking into global scope.
  260. # Importing torch._inductor.runtime.benchmarking will cause cuda initialization
  261. # (because of calling torch.cuda.available in global scope)
  262. # which cause failure in vllm when it create child processes. Check log:
  263. # https://gist.github.com/shunting314/c194e147bf981e58df095c14874dd65a
  264. #
  265. # Another way to solve the issue is to just move do_bench_using_profiling
  266. # to torch._inductor.runtime.benchmarking and change all the call site.
  267. # But that's not trivial due to so many call sites in and out of pytorch.
  268. from torch._inductor.runtime.benchmarking import may_distort_benchmarking_result
  269. return may_distort_benchmarking_result(_do_bench_using_profiling)(
  270. fn, warmup, rep, is_vetted_benchmarking
  271. )
  272. def _do_bench_using_profiling(
  273. fn: Callable[[], Any],
  274. warmup: int = 25,
  275. rep: int = 100,
  276. is_vetted_benchmarking: bool = False,
  277. ) -> float:
  278. """
  279. Returns benchmark results by examining torch profiler events.
  280. This could be more accurate as it doesn't count CPU side overhead.
  281. However, this also requires manually excluding irrelevant event, e.g.
  282. vectorized_elementwise_kernel which is used to fill L2 cache,
  283. various CUDA events, etc, so could also be fragile.
  284. """
  285. if not is_vetted_benchmarking:
  286. from torch._inductor.runtime.benchmarking import may_ban_benchmarking
  287. may_ban_benchmarking()
  288. device_type = get_gpu_type()
  289. device_type_upper = device_type.upper()
  290. device_interface = get_interface_for_device(device_type)
  291. fn()
  292. device_interface.synchronize()
  293. cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device_type)
  294. # Estimate the runtime of the function
  295. start_event = device_interface.Event(enable_timing=True)
  296. end_event = device_interface.Event(enable_timing=True)
  297. start_event.record()
  298. for _ in range(5):
  299. cache.zero_()
  300. fn()
  301. end_event.record()
  302. device_interface.synchronize()
  303. estimate_ms = start_event.elapsed_time(end_event) / 5
  304. # compute number of warmup and repeat
  305. n_warmup = max(1, int(warmup / estimate_ms))
  306. n_repeat = max(1, int(rep / estimate_ms))
  307. # Warm-up
  308. for _ in range(n_warmup):
  309. fn()
  310. device_interface.synchronize()
  311. with torch.profiler.profile(
  312. activities=[
  313. getattr(torch.profiler.ProfilerActivity, device_type_upper),
  314. ]
  315. ) as p:
  316. # Benchmark
  317. for _ in range(n_repeat):
  318. # we clear the L2 cache before each run
  319. cache.zero_()
  320. # record time of `fn`
  321. fn()
  322. # Record clocks
  323. device_interface.synchronize()
  324. log.debug("raw events")
  325. log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))
  326. filtered_events = EventList(
  327. [
  328. event
  329. for event in p.events()
  330. if event.device_type == getattr(DeviceType, device_type_upper)
  331. and event.name != "Context Sync"
  332. ]
  333. )
  334. if len(filtered_events) % n_repeat != 0:
  335. raise RuntimeError(
  336. "Failed to divide all profiling events into #repeat groups. "
  337. "#%s events: %d, #repeats: %s",
  338. device_type,
  339. len(filtered_events),
  340. n_repeat,
  341. )
  342. num_event_per_group = len(filtered_events) / n_repeat
  343. actual_events = EventList(
  344. [
  345. event
  346. for i, event in enumerate(filtered_events)
  347. if i % num_event_per_group != 0
  348. ]
  349. )
  350. actual_events._build_tree()
  351. actual_events = actual_events.key_averages()
  352. log.debug("profiling time breakdown")
  353. log.debug(actual_events.table(row_limit=-1))
  354. res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
  355. log.debug("profiling results: %s ms", res)
  356. return res
  357. @functools.cache
  358. def has_torchvision_roi_align() -> bool:
  359. try:
  360. from torchvision.ops import roi_align # noqa: F401
  361. torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta")
  362. return roi_align is not None and hasattr(
  363. getattr(torch.ops, "torchvision", None), "roi_align"
  364. )
  365. except ImportError:
  366. return False
  367. except RuntimeError as e:
  368. assert "torchvision::nms does not exist" in str(e)
  369. return False
  370. def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
  371. if device is None:
  372. return torch.tensor(0.0).device # default device
  373. if isinstance(device, str):
  374. device = torch.device(device)
  375. if device.type not in ("cpu", "meta") and device.index is None:
  376. device_interface = get_interface_for_device(device.type)
  377. return torch.device(device.type, index=device_interface.Worker.current_device())
  378. return device
  379. def sympy_product(it: Iterable[sympy.Expr]) -> sympy.Expr:
  380. return functools.reduce(operator.mul, it, sympy.S.One)
  381. def sympy_dot(seq1: Sequence[sympy.Expr], seq2: Sequence[sympy.Expr]) -> sympy.Expr:
  382. assert len(seq1) == len(seq2)
  383. return sympy.expand(sum(a * b for a, b in zip(seq1, seq2)))
  384. def unique(it: Iterable[_T]) -> ValuesView[_T]:
  385. return {id(x): x for x in it}.values()
  386. def ceildiv(
  387. number: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
  388. ) -> Union[int, sympy.Expr]:
  389. if isinstance(number, sympy.Expr) or isinstance(denom, sympy.Expr):
  390. return CeilDiv(sympy.sympify(number), sympy.sympify(denom))
  391. # TODO: There is a bug in a call to this function, to repro:
  392. # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
  393. # --amp --only YituTechConvBert --dynamic-shapes
  394. assert isinstance(number, int) and isinstance(denom, int), (
  395. f"{number}: {type(number)}, {denom}: {type(denom)}"
  396. )
  397. return runtime_ceildiv(number, denom)
  398. def _type_of(key: Optional[torch.dtype]) -> str:
  399. # Use the function here to get rid of dependencies on the Triton during the codegen.
  400. # Refer to Triton implementation here:
  401. # https://github.com/triton-lang/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
  402. # `None` is nullptr. Implicitly convert to *i8.
  403. if key is None:
  404. return "*i8"
  405. dtype_str = str(key).split(".")[-1]
  406. tys = {
  407. "bool": "i1",
  408. "float8e4nv": "fp8e4nv",
  409. "float8e5": "fp8e5",
  410. "float8e4b15": "fp8e4b15",
  411. "float8e4b15x4": "fp8e4b15x4",
  412. "float8_e4m3fn": "fp8e4nv",
  413. "float8_e5m2": "fp8e5",
  414. # TODO: remove when support is added in triton
  415. # https://github.com/triton-lang/triton/issues/6054
  416. "float8_e8m0fnu": "u8",
  417. "float4_e2m1fn_x2": "u8",
  418. "float16": "fp16",
  419. "bfloat16": "bf16",
  420. "float32": "fp32",
  421. "float64": "fp64",
  422. "int8": "i8",
  423. "int16": "i16",
  424. "int32": "i32",
  425. "int64": "i64",
  426. "uint8": "u8",
  427. "uint16": "u16",
  428. "uint32": "u32",
  429. "uint64": "u64",
  430. }
  431. # reinterpret can create triton type
  432. tys.update({v: v for v in list(tys.values())})
  433. return key if isinstance(key, str) else f"*{tys[dtype_str]}"
  434. def convert_shape_to_inductor(
  435. lst: Iterable[Union[int, torch.SymInt]],
  436. ) -> list[sympy.Expr]:
  437. """
  438. Gets the shape and stride of a tensor. For non-symbolic tensors, this is
  439. trivial. But for symbolic tensors, we need to map from SymIntNode into
  440. sympy.Expr.
  441. """
  442. return [sympy.sympify(i) for i in lst]
  443. def convert_symint_to_expr(val: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
  444. """
  445. Convert SymInt to sympy.Expr, leave int as is.
  446. Unlike sympy.sympify() which converts int to sympy.Integer,
  447. this function preserves int as int and only converts SymInt to Expr.
  448. """
  449. if isinstance(val, torch.SymInt):
  450. return val.node.expr
  451. return val
  452. def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]:
  453. """
  454. Like convert_shape_to_symint, but operates on a single expression.
  455. """
  456. from .virtualized import V
  457. return (
  458. i
  459. if isinstance(i, int)
  460. else (
  461. int(i)
  462. if isinstance(i, sympy.Integer)
  463. else V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
  464. )
  465. )
  466. def convert_shape_to_symint(
  467. lst: Iterable[Union[int, sympy.Expr]],
  468. ) -> list[Union[int, torch.SymInt]]:
  469. """
  470. Takes a list of shapes from Inductor and converts them into symints (or just
  471. ints if all shapes are static).
  472. """
  473. return [convert_to_symint(i) for i in lst]
  474. def is_view(op: torch._ops.OpOverload) -> bool:
  475. """
  476. Does this op overload have aliasing
  477. """
  478. return any(a.alias_info is not None for a in op._schema.arguments)
  479. def is_pointwise_use(
  480. use: Node,
  481. is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
  482. ) -> bool:
  483. """
  484. Do all uses of this op have torch.Tag.pointwise or return True for optional `is_pointwise_fn`
  485. Uses in views ops will follow the views uses
  486. """
  487. if use.op != "call_function":
  488. return False
  489. if not (
  490. isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem
  491. ):
  492. return False
  493. target = cast(torch._ops.OpOverload, use.target)
  494. if target is operator.getitem or is_view(target):
  495. return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
  496. return torch.Tag.pointwise in target.tags or is_pointwise_fn(target)
  497. def gen_gm_and_inputs(
  498. target: Any, args: list[Any], kwargs: dict[str, Any]
  499. ) -> tuple[GraphModule, list[torch.Tensor]]:
  500. g = torch.fx.Graph()
  501. graph_args: list[torch.Tensor] = []
  502. def add_tensor_arg(arg: torch.Tensor) -> Node:
  503. graph_args.append(arg)
  504. return g.placeholder(f"arg{len(graph_args)}")
  505. node = g.call_function(
  506. target, *tree_map_only(torch.Tensor, add_tensor_arg, (args, kwargs))
  507. )
  508. if (
  509. len(target._schema.returns) == 1
  510. and str(target._schema.returns[0].type) == "Tensor"
  511. ):
  512. node = (node,) # type: ignore[assignment]
  513. g.output(node)
  514. gm = torch.fx.GraphModule({}, g)
  515. return gm, graph_args
  516. def synchronize(device: str = "cuda") -> None:
  517. if device == "cpu":
  518. return
  519. device_interface = get_interface_for_device(device)
  520. if device_interface.is_available():
  521. device_interface.synchronize()
  522. def timed(
  523. model: Callable[..., Any],
  524. example_inputs: Sequence[Any],
  525. times: int = 1,
  526. device: str = "cuda",
  527. ) -> float:
  528. synchronize(device)
  529. torch.manual_seed(1337)
  530. t0 = time.perf_counter()
  531. for _ in range(times):
  532. result = model(*example_inputs)
  533. synchronize(device)
  534. t1 = time.perf_counter()
  535. # GC the result after timing
  536. assert result is not None # type: ignore[possibly-undefined]
  537. return t1 - t0
  538. def print_performance(
  539. model: Callable[..., Any],
  540. example_inputs: Sequence[Any] = (),
  541. times: int = 10,
  542. repeat: int = 10,
  543. baseline: float = 1.0,
  544. device: str = "cuda",
  545. ) -> float:
  546. timings = torch.tensor(
  547. [timed(model, example_inputs, times, device) for _ in range(repeat)]
  548. )
  549. took = torch.median(timings) / times
  550. print(f"{took / baseline:.6f}")
  551. return took.item()
  552. def precompute_method(obj: Any, method: str) -> None:
  553. """Replace obj.method() with a new method that returns a precomputed constant."""
  554. result = getattr(obj, method)()
  555. setattr(obj, method, lambda: result)
  556. def precompute_methods(obj: Any, methods: list[str]) -> None:
  557. """Replace methods with new methods that returns a precomputed constants."""
  558. for method in methods:
  559. precompute_method(obj, method)
  560. def cmp(a: int, b: int) -> int:
  561. return int(a > b) - int(a < b)
  562. def pad_listlike(x: Union[int, Sequence[int]], size: int) -> Sequence[int]:
  563. if isinstance(x, int):
  564. return [x] * size
  565. if len(x) == 1:
  566. return type(x)([x[0]]) * size # type: ignore[call-arg, operator, return-value]
  567. return x
  568. # Used to ensure that iterating over a set is deterministic
  569. def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
  570. if len(x) == 0:
  571. return []
  572. def sort_func(elem: _T) -> str:
  573. if isinstance(elem, str):
  574. return elem
  575. from .scheduler import BaseSchedulerNode
  576. assert isinstance(elem, BaseSchedulerNode)
  577. return elem.get_name()
  578. return sorted(x, key=sort_func)
  579. P = ParamSpec("P")
  580. RV = TypeVar("RV", covariant=True)
  581. FN_TYPE = Callable[Concatenate[Any, P], RV]
  582. class CachedMethod(Protocol, Generic[P, RV]):
  583. @staticmethod
  584. def clear_cache(cache: Any) -> None: ...
  585. def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ...
  586. # See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
  587. def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
  588. name = fn.__name__
  589. key = f"__{name}_cache"
  590. # wrapper is likely on the hot path, compile a specialized version of it
  591. ctx = {"fn": fn}
  592. exec(
  593. f"""\
  594. def {name}_cache_on_self(self):
  595. try:
  596. return self.{key}
  597. except AttributeError:
  598. pass
  599. rv = fn(self)
  600. object.__setattr__(self, "{key}", rv)
  601. return rv
  602. """.lstrip(),
  603. ctx,
  604. )
  605. wrapper = functools.wraps(fn)(ctx[f"{name}_cache_on_self"])
  606. def clear_cache(self: Any) -> None:
  607. if hasattr(self, key):
  608. delattr(self, key)
  609. wrapper.clear_cache = clear_cache # type: ignore[attr-defined]
  610. return wrapper # type: ignore[return-value]
  611. def cache_property_on_self(
  612. fn: Callable[Concatenate[Any, P], RV],
  613. ) -> CachedMethod[P, RV]:
  614. """
  615. Variant of cache_on_self for properties. The only difference is the type signature.
  616. """
  617. return cache_on_self(fn)
  618. def cache_on_self_and_args(
  619. class_name: str,
  620. ) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
  621. # include both class_name and fn_name in the key to support `super().fn(self, **args, **kwargs)` calls.
  622. def wrapper(
  623. fn: FN_TYPE[P, RV],
  624. ) -> FN_TYPE[P, RV]:
  625. key = f"__{class_name}_{fn.__name__}_cache"
  626. # wrapper is likely on the hot path, compile a specialized version of it
  627. ctx = {"fn": fn}
  628. exec(
  629. f"""\
  630. def inner(self: Any, *args: P.args, **kwargs: P.kwargs) -> RV:
  631. args_kwargs = (args, tuple(sorted(kwargs.items())))
  632. if not hasattr(self, "{key}"):
  633. object.__setattr__(self, "{key}", {{}})
  634. cache = self.{key}
  635. try:
  636. return cache[args_kwargs]
  637. except KeyError:
  638. pass
  639. rv = fn(self, *args, **kwargs)
  640. cache[args_kwargs] = rv
  641. return rv
  642. """.lstrip(),
  643. ctx,
  644. )
  645. inner = functools.wraps(fn)(ctx["inner"])
  646. def clear_cache(self: Any) -> None:
  647. if hasattr(self, key):
  648. delattr(self, key)
  649. inner.clear_cache = clear_cache # type: ignore[attr-defined]
  650. return inner
  651. return wrapper
  652. def aggregate_origins(
  653. node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
  654. ) -> OrderedSet[Node]:
  655. from . import ir
  656. if isinstance(node_schedule, list):
  657. return functools.reduce(
  658. operator.or_,
  659. [
  660. # pyrefly: ignore [missing-attribute]
  661. node.node.origins
  662. for node in node_schedule
  663. if hasattr(node, "node") and node.node
  664. ],
  665. OrderedSet(),
  666. )
  667. elif isinstance(node_schedule, ir.ExternKernel):
  668. return node_schedule.origins
  669. else:
  670. return OrderedSet()
  671. def get_fused_kernel_name(
  672. node_schedule: Sequence[BaseSchedulerNode],
  673. descriptive_names: Literal[True, "torch", "original_aten", "inductor_node"],
  674. ) -> str:
  675. all_origins = aggregate_origins(node_schedule)
  676. if descriptive_names == "original_aten":
  677. def get_origin_meta_str(origin):
  678. original_aten = origin.meta["original_aten"]
  679. key = ""
  680. if isinstance(original_aten, torch._ops.OpOverload):
  681. key = original_aten._overloadpacket.__name__
  682. elif isinstance(original_aten, torch._ops.HigherOrderOperator):
  683. key = str(original_aten.name())
  684. return key
  685. # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions)
  686. sources = [
  687. get_origin_meta_str(origin)
  688. for origin in all_origins
  689. if origin.op == "call_function"
  690. and "original_aten" in origin.meta
  691. and origin.meta["original_aten"] is not None
  692. ]
  693. sources = sorted(OrderedSet(sources))
  694. elif descriptive_names == "torch":
  695. # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph)
  696. sources = []
  697. for origin in all_origins:
  698. if origin.op == "call_function":
  699. source_fn = None
  700. suffix = ""
  701. if "source_fn_stack" in origin.meta:
  702. source_fn = origin.meta["source_fn_stack"][-1]
  703. elif "fwd_source_fn_stack" in origin.meta:
  704. # backward nodes have "fwd_source_fn_stack" instead
  705. source_fn = origin.meta["fwd_source_fn_stack"][-1]
  706. suffix = "backward"
  707. if not source_fn:
  708. continue
  709. if isinstance(source_fn[1], str):
  710. sources.append(source_fn[1] + suffix)
  711. else:
  712. sources.append(source_fn[1].__name__ + suffix)
  713. sources = sorted(OrderedSet(sources))
  714. elif descriptive_names == "inductor_node":
  715. sources = [
  716. origin.name for origin in all_origins if origin.op == "call_function"
  717. ]
  718. else:
  719. raise NotImplementedError
  720. return "_".join(["fused"] + sources)
def get_kernel_metadata(
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
    wrapper: PythonWrapperCodegen,
) -> tuple[str, str]:
    """
    Retrieves metadata information for a kernel.

    The metadata is derived from the FX ``call_function`` origins of
    ``node_schedule``, as collected by ``aggregate_origins``.

    Args:
        node_schedule (Union[Sequence[BaseSchedulerNode], ExternKernel]):
            Either a sequence of BaseSchedulerNode objects or an ExternKernel instance.
        wrapper (PythonWrapperCodegen):
            An instance of PythonWrapperCodegen. Its ``comment`` attribute is used
            as the prefix of every emitted line.

    Returns:
        tuple[str, str]:
            A tuple containing two strings:
                - A one-line summary listing the source nodes and the original
                  ATen ops, marked "Topologically Sorted" or "Unsorted".
                - A newline-joined detailed string. It always has the source node
                  to ATen node mapping. When every origin belongs to a single FX
                  graph, it also has a "Graph fragment": placeholder lines for
                  buffers read, the origin nodes, and a ``return`` line listing
                  buffers written.
    """
    all_origins = aggregate_origins(node_schedule)
    inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]

    # source ("from_node") node name -> names of origin nodes derived from it
    from_node_dict = collections.defaultdict(list)
    # original ATen op name -> names of origin nodes that carry it
    original_aten_dict = collections.defaultdict(list)

    # Attempt to sort `inductor_nodes` topologically. Note that the case
    # where `inductor_nodes` contains nodes from multiple graph instances
    # is not supported. An example of this is conditional statements.
    single_graph = None
    if inductor_nodes:
        unique_graphs = OrderedSet(n.graph for n in inductor_nodes)
        if len(unique_graphs) == 1:
            single_graph = inductor_nodes[0].graph
            # Map each node to its position in the graph and cache the map on the
            # graph object. Once cached it is reused as-is on later calls.
            if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"):
                node_to_idx_map = {n: idx for idx, n in enumerate(single_graph.nodes)}
                single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map  # type: ignore[attr-defined]
            inductor_nodes.sort(
                key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n]  # type: ignore[attr-defined]
            )

    for node in inductor_nodes:
        if "original_aten" in node.meta and node.meta["original_aten"] is not None:
            original_aten = node.meta["original_aten"]
            key = None
            if isinstance(original_aten, torch._ops.OpOverload):
                key = str(original_aten._overloadpacket)
            elif isinstance(original_aten, torch._ops.HigherOrderOperator):
                key = str(original_aten.name())
            if key:
                original_aten_dict[key].append(node.name)
        if "from_node" in node.meta:
            key = node.meta["from_node"][0].name
            from_node_dict[key].append(node.name)
        elif node.meta.get("partitioner_tag") == "is_backward":
            # backward nodes currently don't have a "from node"
            from_node_dict[node.name].append(node.name)
    sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted"
    metadata = (
        f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], "
        f"Original ATen: [{', '.join(original_aten_dict.keys())}]"
    )

    # trace back to original node here
    detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"]
    for original_node, nodes in sorted(from_node_dict.items()):
        detailed_metadata.append(
            f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
        )

    # print the aot_autograd graph fragment
    if single_graph is not None:
        from . import ir

        detailed_metadata.append(f"{wrapper.comment} Graph fragment:")
        all_reads: OrderedSet[str] = OrderedSet()
        all_writes: list[str] = []
        # Placeholders and the return list are only derived from the
        # read/write info of scheduler nodes, not from an ExternKernel.
        if not isinstance(node_schedule, ir.ExternKernel):
            from .virtualized import V

            def get_buffer_info(
                buffer: Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject], rw_name: str
            ) -> tuple[str, ir.Layout | None]:
                # Returns (display name, layout or None). The display name is
                # the buffer's FX origin node name (looking through a
                # TensorBox -> StorageBox wrapper), falling back to `rw_name`.
                if isinstance(buffer, ir.TensorBox) and isinstance(
                    buffer.data, ir.StorageBox
                ):
                    origin_node = buffer.data.data.origin_node
                else:
                    origin_node = buffer.origin_node
                if origin_node is None:
                    # use the read/write name if no origin node is found
                    name = rw_name
                else:
                    name = origin_node.name
                try:
                    layout = buffer.get_layout()
                except NotImplementedError:
                    layout = None
                return name, layout

            def stringify_shape(shape: Iterable[int]) -> str:
                # "[a, b, c]"
                return f"[{', '.join([str(x) for x in shape])}]"

            def stringfy_layout(layout: ir.Layout | None) -> str:
                # Quoted "<dtype abbrev>[sizes][strides]<device>", or "" if unknown.
                if layout is None:
                    return ""
                shape_annotation = f"{stringify_shape(layout.size)}"
                stride_annotation = f"{stringify_shape(layout.stride)}"
                device_annotation = f"{layout.device}"

                return (
                    f'"{dtype_abbrs[layout.dtype]}{shape_annotation}'
                    f'{stride_annotation}{device_annotation}"'
                )

            for n in node_schedule:
                if not hasattr(n, "read_writes") or n.read_writes is None:
                    continue
                if hasattr(n.read_writes, "reads") and n.read_writes.reads is not None:
                    for r in n.read_writes.reads:
                        # Remove the duplicated inputs
                        if r.name in all_reads:
                            continue
                        all_reads.add(r.name)
                        buffer = V.graph.try_get_buffer(r.name)
                        if buffer is None:
                            continue
                        input_name, layout = get_buffer_info(buffer, r.name)
                        detailed_metadata.append(
                            f"{wrapper.comment} %{input_name} : Tensor "
                            f"{stringfy_layout(layout)} = PlaceHolder[target={input_name}]"
                        )

                if (
                    hasattr(n.read_writes, "writes")
                    and n.read_writes.writes is not None
                ):
                    for w in n.read_writes.writes:
                        buffer = V.graph.try_get_buffer(w.name)
                        if buffer is None:
                            continue
                        output_name, _ = get_buffer_info(buffer, w.name)
                        all_writes.append("%" + output_name)

        # One line per origin node, in (topologically sorted) graph order.
        for node in inductor_nodes:
            detailed_metadata.append(
                f"{wrapper.comment} {node.format_node(include_tensor_metadata=True)}"
            )

        detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}")

    return metadata, "\n".join(detailed_metadata)
  856. def dominated_nodes(
  857. initial_queue: Iterable[torch.fx.Node],
  858. skip_filter: Optional[Callable[[Any], bool]] = None,
  859. ) -> OrderedSet[torch.fx.Node]:
  860. """Returns the set of nodes whose values depend on those within initial_queue"""
  861. initial_queue = list(initial_queue)
  862. dominated_set = OrderedSet(initial_queue)
  863. while initial_queue:
  864. node = initial_queue.pop()
  865. for user in node.users:
  866. if skip_filter and skip_filter(user):
  867. continue
  868. if user not in dominated_set:
  869. dominated_set.add(user)
  870. initial_queue.append(user)
  871. return dominated_set
  872. def gather_origins(
  873. args: Sequence[IRNode], kwargs: dict[str, IRNode]
  874. ) -> OrderedSet[torch.fx.Node]:
  875. from . import ir
  876. def is_unrealized_node(n: IRNode) -> bool:
  877. if isinstance(n, ir.TensorBox):
  878. return is_unrealized_node(n.data)
  879. if isinstance(n, ir.StorageBox):
  880. return is_unrealized_node(n.data)
  881. return isinstance(n, ir.IRNode) and not isinstance(
  882. n,
  883. (
  884. ir.ComputedBuffer,
  885. ir.InputsKernel,
  886. ir.InputBuffer,
  887. ir.TemplateBuffer,
  888. ),
  889. )
  890. # kwargs and args may include a container of node, for example torch.cat([t1, t2])
  891. # flatten them before search the unrealized nodes
  892. kwargs_flatten, _ = tree_flatten(kwargs)
  893. kwargs_origins = [val.origins for val in kwargs_flatten if is_unrealized_node(val)]
  894. args_flatten, _ = tree_flatten(args)
  895. args_origins = [val.origins for val in args_flatten if is_unrealized_node(val)]
  896. return OrderedSet(itertools.chain(*args_origins, *kwargs_origins))
  897. def sympy_str(expr: sympy.Expr) -> str:
  898. """
  899. Normal sympy str is very slow, this is a lot faster. The result are
  900. somewhat worse, as it doesn't do as much simplification. So don't
  901. use this for final codegen.
  902. """
  903. def is_neg_lead(expr: sympy.Expr) -> bool:
  904. return (
  905. isinstance(expr, sympy.Mul) and len(expr.args) == 2 and expr.args[0] == -1
  906. )
  907. def sympy_str_add(expr: sympy.Expr) -> str:
  908. if isinstance(expr, sympy.Add):
  909. # Special case 'a - b'. Note that 'a - b - c' will still appear as
  910. # 'a + -1 * b + -1 * c'.
  911. if len(expr.args) == 2 and is_neg_lead(expr.args[1]):
  912. return f"{sympy_str_mul(expr.args[0])} - {sympy_str_mul(expr.args[1].args[1])}"
  913. else:
  914. return " + ".join(map(sympy_str_mul, expr.args))
  915. else:
  916. return sympy_str_mul(expr)
  917. def sympy_str_mul(expr: sympy.Expr) -> str:
  918. if isinstance(expr, sympy.Mul):
  919. if is_neg_lead(expr):
  920. # Special case '-a'. Note that 'a * -b' will still appear as
  921. # '-1 * a * b'.
  922. return f"-{sympy_str_atom(expr.args[1])}"
  923. else:
  924. return " * ".join(map(sympy_str_atom, expr.args))
  925. else:
  926. return sympy_str_atom(expr)
  927. def sympy_str_atom(expr: sympy.Expr) -> str:
  928. if isinstance(expr, sympy.Symbol):
  929. return expr.name
  930. elif isinstance(expr, (sympy.Add, sympy.Mul)):
  931. return f"({sympy_str_add(expr)})"
  932. elif isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)):
  933. return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
  934. else:
  935. return str(expr)
  936. return sympy_str_add(expr)
  937. def get_bounds_index_expr(index: sympy.Expr) -> ValueRanges[Any]:
  938. from .virtualized import V
  939. # If this expression does not come from an FX node, we compute its bounds
  940. if (
  941. config.compute_all_bounds
  942. and (fx_node := getattr(V.interpreter, "current_node", None))
  943. and fx_node.target != "index_expr"
  944. ):
  945. return bound_sympy(index)
  946. else:
  947. return ValueRanges.unknown()
  948. def prefix_is_reduction(prefix: str) -> bool:
  949. return prefix[0] == "r"
  950. def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
  951. """
  952. Used to generate an integer-nonnegative symbol.
  953. """
  954. # This should never be used for creating shape/stride symbols, as those
  955. # should all be allocated before Inductor.
  956. assert prefix != SymT.SIZE
  957. # NOTE: shape symbols are positive (> 0), but index variables are only
  958. # non-negative (>= 0).
  959. return make_symbol(prefix, idx, integer=True, nonnegative=True)
  960. def generate_assert(check: bool) -> bool:
  961. return (check or config.debug_index_asserts) and config.assert_indirect_indexing
  962. def sympy_index_symbol(name: str) -> sympy.Symbol:
  963. """
  964. Used to generate an integer-nonnegative symbol.
  965. """
  966. # This should never be used for creating shape/stride symbols, as those
  967. # should all be allocated before Inductor.
  968. assert name[0] != "s"
  969. # NOTE: shape symbols are positive (> 0), but index variables are only
  970. # non-negative (>= 0).
  971. return sympy.Symbol(name, integer=True, nonnegative=True)
  972. def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.Expr:
  973. """
  974. When the passed replacement symbol v is a string, it is converted to a symbol with name v that
  975. have the same replaced expression integer and nonnegative properties.
  976. """
  977. def to_symbol(
  978. replaced: sympy.Expr, replacement: Union[sympy.Expr, str]
  979. ) -> sympy.Symbol:
  980. assert isinstance(replaced, sympy.Expr)
  981. if isinstance(replacement, str):
  982. return sympy.Symbol(
  983. replacement,
  984. integer=replaced.is_integer, # type: ignore[attr-defined]
  985. nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined]
  986. )
  987. else:
  988. return replacement
  989. # xreplace is faster than subs, but is way more picky
  990. return sympy.sympify(expr).xreplace(
  991. {k: to_symbol(k, v) for k, v in replacements.items()}
  992. )
  993. def is_symbolic(a: Any) -> TypeGuard[Union[torch.SymInt, torch.Tensor]]:
  994. return isinstance(a, torch.SymInt) or (
  995. isinstance(a, torch.Tensor) and a._has_symbolic_sizes_strides
  996. )
  997. def any_is_symbolic(*args: Any) -> bool:
  998. return any(is_symbolic(a) for a in args)
  999. # Ops that are fundamentally incompatible with CUDA graph capture
  1000. # (e.g., CPU synchronization, dynamic memory allocation, etc.)
  1001. FORBIDDEN_CUDAGRAPH_OPS = frozenset(
  1002. [
  1003. "aten._fused_moving_avg_obs_fq_helper.default",
  1004. "aten._fused_moving_avg_obs_fq_helper_functional.default",
  1005. "fbgemm.dense_to_jagged.default",
  1006. "fbgemm.jagged_to_padded_dense.default",
  1007. "run_and_save_rng_state",
  1008. "run_with_rng_state",
  1009. "aten._local_scalar_dense",
  1010. # Technically, it's not necessary to ban this, because an
  1011. # assert_scalar with constant arguments can be validly run
  1012. # with CUDA graphs, but the operator is also pointless with
  1013. # constant arguments, so might as well ban
  1014. "aten._assert_scalar",
  1015. ]
  1016. )
  1017. def get_first_incompatible_cudagraph_node(
  1018. gm: torch.fx.GraphModule,
  1019. ) -> Optional[torch.fx.Node]:
  1020. from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
  1021. for node in gm.graph.nodes:
  1022. if is_cudagraph_unsafe_fx_node(node):
  1023. return node
  1024. if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val):
  1025. return node
  1026. return None
  1027. def output_node(gm: torch.fx.GraphModule) -> Node:
  1028. """Get the output node from an FX graph"""
  1029. last_node = next(iter(reversed(gm.graph.nodes)))
  1030. assert last_node.op == "output"
  1031. return last_node
  1032. def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
  1033. placeholder_nodes = gm.graph.find_nodes(op="placeholder")
  1034. input_devices: OrderedSet[torch.device] = OrderedSet(
  1035. node.meta["val"].device
  1036. for node in placeholder_nodes
  1037. if isinstance(node.meta.get("val"), torch.Tensor)
  1038. )
  1039. out_arg = output_node(gm).args[0] # type: ignore[union-attr]
  1040. out_args = out_arg if isinstance(out_arg, tuple) else (out_arg,)
  1041. out_devices: OrderedSet[torch.device] = OrderedSet(
  1042. arg.meta["val"].device
  1043. for arg in out_args
  1044. if isinstance(arg, torch.fx.Node)
  1045. and isinstance(arg.meta.get("val"), torch.Tensor)
  1046. )
  1047. return input_devices | out_devices
  1048. import gc
def unload_xpu_triton_pyds() -> None:
    """
    Best-effort teardown of Triton extension modules loaded by inductor.

    ``fresh_cache`` calls this only when ``is_windows()`` and XPU is available,
    right before it removes the temporary cache dir. Presumably this lets the
    loaded launcher/driver binaries stop blocking that deletion on Windows
    (TODO confirm).
    """
    # unload __triton_launcher.pyd
    # Iterate over a snapshot of the keys: entries are deleted from
    # sys.modules inside this loop.
    for module_name in list(sys.modules.keys()):
        # Only modules generated under inductor's compile_tasks namespace.
        if not module_name.startswith("torch._inductor.runtime.compile_tasks."):
            continue
        m = sys.modules[module_name]
        for attr_name in m.__dict__:
            if attr_name.startswith("triton_"):
                kernel = getattr(m, attr_name)
                if isinstance(
                    kernel, torch._inductor.runtime.triton_heuristics.CachingAutotuner
                ):
                    for result in kernel.compile_results:
                        if isinstance(
                            result,
                            torch._inductor.runtime.triton_heuristics.TritonCompileResult,
                        ):
                            # Explicitly finalize the launcher module object
                            # (assumed to release its loaded binary).
                            # pyrefly: ignore [missing-attribute]
                            result.kernel.run.mod.__del__()
        del sys.modules[module_name]

    # unload spirv_utils.pyd
    if "triton.runtime.driver" in sys.modules:
        mod = sys.modules["triton.runtime.driver"]
        # Drop the class-level `instance` attribute of the active driver's
        # utils type, then the utils object itself. This looks like a cached
        # singleton being released; verify against Triton's XPU driver.
        del type(mod.driver.active.utils).instance
        del mod.driver.active.utils

    # Collect the objects released above.
    gc.collect()
  1075. _registered_caches: list[Any] = []
  1076. def clear_on_fresh_cache(obj: Any) -> Any:
  1077. """
  1078. Use this decorator to register any caches that should be cache_clear'd
  1079. with fresh_cache().
  1080. """
  1081. if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear):
  1082. raise AttributeError(f"{obj} does not have a cache_clear method")
  1083. _registered_caches.append(obj)
  1084. return obj
  1085. def clear_caches() -> None:
  1086. """
  1087. Clear all registered caches.
  1088. """
  1089. for obj in _registered_caches:
  1090. obj.cache_clear()
  1091. @contextlib.contextmanager
  1092. def _set_env(key: str, value: str) -> Iterator[None]:
  1093. """Thread-safe env var set/restore using atomic C-level lookups.
  1094. We avoid mock.patch.dict(os.environ, ...) because it internally calls
  1095. os.environ.copy(), which iterates all env var keys then fetches values in
  1096. separate steps. That approach is not atomic and can race with background threads
  1097. (e.g. Triton async compilation) modifying the environment, causing KeyError,
  1098. so we use os.environ.get() for individual keys which is an atomic C-level lookup.
  1099. """
  1100. old = os.environ.get(key)
  1101. try:
  1102. os.environ[key] = value
  1103. yield
  1104. finally:
  1105. if old is None:
  1106. os.environ.pop(key, None)
  1107. else:
  1108. os.environ[key] = old
@contextlib.contextmanager
def fresh_cache(
    cache_entries: Optional[dict[str, Any]] = None,
    dir: Optional[str] = None,
    delete: bool = True,
) -> Iterator[None]:
    """
    Contextmanager that provides a clean tmp cachedir for pt2 caches.

    Registered caches are cleared on entry and again on exit (in ``finally``).
    While the context is active, TORCHINDUCTOR_CACHE_DIR points at a new
    temporary directory and TRITON_CACHE_DIR at its ``triton`` subdirectory.

    Args:
        cache_entries: if a dict is passed, it must be empty. After the body
            completes without an exception, it is filled with
            ``{filename: size in bytes}`` for the files directly inside the
            Triton cache dir whose names do not contain ".lock".
        dir: parent directory for ``tempfile.mkdtemp`` (None = system default).
        delete: remove the temporary directory once the body completes without
            an exception. If the body raises, the directory is kept and its
            path is logged.
    """
    # Start from empty registered caches.
    clear_caches()

    from torch._inductor.cpp_builder import normalize_path_separator

    inductor_cache_dir = normalize_path_separator(tempfile.mkdtemp(dir=dir))
    try:
        with _set_env("TORCHINDUCTOR_CACHE_DIR", inductor_cache_dir):
            log.debug("Using inductor cache dir %s", inductor_cache_dir)
            triton_cache_dir = normalize_path_separator(
                os.path.join(inductor_cache_dir, "triton")
            )
            with _set_env("TRITON_CACHE_DIR", triton_cache_dir):
                yield
                # Reached only when the body exits without an exception.
                if isinstance(cache_entries, dict):
                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
                    if os.path.exists(triton_cache_dir):
                        files = os.listdir(triton_cache_dir)
                        cache_entries.update(
                            {
                                f: os.path.getsize(os.path.join(triton_cache_dir, f))
                                for f in files
                                if ".lock" not in f
                            }
                        )
        if delete:
            if is_windows() and torch.xpu.is_available():
                # Release loaded Triton extension modules before deleting the dir.
                unload_xpu_triton_pyds()

            # On non-Windows, removal errors go to the onerror handler and are
            # logged. On Windows, ignore_errors=True makes shutil skip them.
            shutil.rmtree(
                inductor_cache_dir,
                # Let's not fail if we can't clean up the temp dir. Also note that for
                # Windows, we can't delete the loaded modules because the module binaries
                # are open.
                ignore_errors=is_windows(),
                onerror=lambda func, path, exc_info: log.warning(
                    "Failed to remove temporary cache dir at %s",
                    inductor_cache_dir,
                    exc_info=exc_info,
                ),
            )
    except Exception:
        # The temp dir is not deleted on this path.
        log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
        raise
    finally:
        clear_caches()
  1162. # Deprecated functions -- only keeping them for BC reasons
  1163. clear_on_fresh_inductor_cache = clear_on_fresh_cache
  1164. clear_inductor_caches = clear_caches
  1165. fresh_inductor_cache = fresh_cache
  1166. def argsort(seq: Sequence[Any], *, reverse: bool = False) -> list[int]:
  1167. getter = seq.__getitem__
  1168. a_r = range(len(seq))
  1169. # preserve original order for equal strides
  1170. # e.g. if strides are [32, 8, 8, 1]
  1171. # argsort -> [3, 2, 1, 0], rather than
  1172. # [3, 1, 2, 0]
  1173. # i.e. for equal strides in ascending order (reverse=False) an
  1174. # inner dimension should come before an outer dimension, and vice versa
  1175. # for descending
  1176. sort_idx = list(sorted(a_r, key=getter, reverse=True)) # noqa: C413
  1177. if not reverse:
  1178. return list(reversed(sort_idx))
  1179. return sort_idx
  1180. def argsort_sym(
  1181. shape_env: ShapeEnv,
  1182. seq: Sequence[Union[int, torch.SymInt, sympy.Expr]],
  1183. *,
  1184. reverse: bool = False,
  1185. ) -> list[int]:
  1186. def cmp(a: tuple[int, sympy.Expr], b: tuple[int, sympy.Expr]) -> int:
  1187. a_idx, a_val = a
  1188. b_idx, b_val = b
  1189. def evaluate(expr: Union[bool, torch.SymInt, sympy.Expr]) -> bool:
  1190. if isinstance(expr, bool):
  1191. return expr
  1192. return shape_env.evaluate_expr(expr, size_oblivious=True)
  1193. if evaluate(a_val < b_val):
  1194. return -1
  1195. if evaluate(a_val > b_val):
  1196. return 1
  1197. # If strides are the same, prefer the original order.
  1198. # (this matches argsort's algorithm).
  1199. # For strides = [2048, 2048, 16, 1], this is
  1200. # [3, 2, 1, 0].
  1201. if a_idx < b_idx:
  1202. return 1
  1203. if a_idx > b_idx:
  1204. return -1
  1205. return 0
  1206. # Strategy: convert all symints to sympy.Expr, then use a custom comparator
  1207. exprs = [
  1208. (idx, s.node.expr if isinstance(s, torch.SymInt) else s)
  1209. for idx, s in enumerate(seq)
  1210. ]
  1211. exprs = sorted(exprs, key=functools.cmp_to_key(cmp), reverse=reverse)
  1212. result = [idx for idx, _ in exprs]
  1213. return result
  1214. @functools.lru_cache(8)
  1215. def get_dtype_size(dtype: torch.dtype) -> int:
  1216. # TODO: Investigate why uint64 tensor creation causes overflow error:
  1217. # Workaround for RuntimeError in memory size calculation, but underlying cause unclear
  1218. if dtype == torch.uint64:
  1219. return 8
  1220. return torch.empty((), dtype=dtype).element_size()
  1221. class LineContext(NamedTuple):
  1222. context: Any
  1223. @dataclasses.dataclass
  1224. class ValueWithLineMap:
  1225. value: str
  1226. line_map: list[tuple[int, LineContext]]
  1227. class IndentedBuffer:
  1228. tabwidth = 4
  1229. def __init__(self, initial_indent: int = 0) -> None:
  1230. self._lines: list[Union[DeferredLineBase, LineContext, str]] = []
  1231. self._indent = initial_indent
  1232. @contextlib.contextmanager
  1233. def set_tabwidth(self, tabwidth: int) -> Iterator[None]:
  1234. prev = self.tabwidth
  1235. try:
  1236. self.tabwidth = tabwidth
  1237. yield
  1238. finally:
  1239. self.tabwidth = prev
  1240. def getvaluewithlinemap(self) -> ValueWithLineMap:
  1241. buf = StringIO()
  1242. p = 1
  1243. linemap: list[tuple[int, LineContext]] = []
  1244. for li in self._lines:
  1245. if isinstance(li, DeferredLineBase):
  1246. line = li()
  1247. if line is None:
  1248. continue
  1249. elif isinstance(li, LineContext):
  1250. linemap.append((p, li.context))
  1251. continue
  1252. else:
  1253. line = li
  1254. assert isinstance(line, str)
  1255. buf.write(line)
  1256. buf.write("\n")
  1257. p += 1 + line.count("\n")
  1258. return ValueWithLineMap(buf.getvalue(), linemap)
  1259. def getvalue(self) -> str:
  1260. return self.getvaluewithlinemap().value
  1261. def getrawvalue(self) -> str:
  1262. buf = StringIO()
  1263. for li in self._lines:
  1264. if isinstance(li, DeferredLineBase):
  1265. line = li()
  1266. if line is None:
  1267. continue
  1268. elif isinstance(li, LineContext):
  1269. continue
  1270. else:
  1271. line = li
  1272. assert isinstance(line, str)
  1273. # backslash implies line continuation
  1274. if line.endswith("\\"):
  1275. buf.write(line[:-1])
  1276. else:
  1277. buf.write(line)
  1278. buf.write("\n")
  1279. return buf.getvalue()
  1280. def clear(self) -> None:
  1281. self._lines.clear()
  1282. def __bool__(self) -> bool:
  1283. return bool(self._lines)
  1284. def prefix(self) -> str:
  1285. return " " * (self._indent * self.tabwidth)
  1286. def newline(self) -> None:
  1287. self.writeline("\n")
  1288. def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None:
  1289. if isinstance(line, LineContext):
  1290. self._lines.append(line)
  1291. elif isinstance(line, DeferredLineBase):
  1292. self._lines.append(line.with_prefix(self.prefix()))
  1293. elif line.strip():
  1294. self._lines.append(f"{self.prefix()}{line}")
  1295. else:
  1296. self._lines.append("")
  1297. def writelines(
  1298. self, lines: Sequence[Union[LineContext, DeferredLineBase, str]]
  1299. ) -> None:
  1300. for line in lines:
  1301. self.writeline(line)
  1302. def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
  1303. @contextlib.contextmanager
  1304. def ctx() -> Iterator[None]:
  1305. self._indent += offset
  1306. try:
  1307. yield
  1308. finally:
  1309. self._indent -= offset
  1310. return ctx()
  1311. def do_indent(self, offset: int = 1) -> None:
  1312. self._indent += offset
  1313. def do_unindent(self, offset: int = 1) -> None:
  1314. self._indent -= offset
  1315. def splice(
  1316. self, other_code: Union[IndentedBuffer, str], strip: bool = False
  1317. ) -> None:
  1318. if isinstance(other_code, IndentedBuffer):
  1319. dedent = float("inf")
  1320. for line in other_code._lines:
  1321. if not isinstance(line, LineContext) and line:
  1322. dedent = min(dedent, len(line) - len(line.lstrip()))
  1323. if math.isinf(dedent):
  1324. dedent = 0
  1325. for line in other_code._lines:
  1326. if isinstance(line, LineContext):
  1327. self._lines.append(line)
  1328. else:
  1329. IndentedBuffer.writeline(self, line[int(dedent) :])
  1330. else:
  1331. other_code = textwrap.dedent(other_code)
  1332. if strip:
  1333. other_code = other_code.lstrip()
  1334. if not other_code:
  1335. return
  1336. other_code = other_code.rstrip()
  1337. for s in other_code.split("\n"):
  1338. self.writeline(s)
  1339. def map(self, func: Callable[[Any], Any]) -> IndentedBuffer:
  1340. res = IndentedBuffer(initial_indent=self._indent)
  1341. res._lines = [func(line) for line in self._lines]
  1342. return res
  1343. def __repr__(self) -> str:
  1344. return f"{type(self)}({self.getvalue()})"
  1345. def __add__(self, other: Self) -> IndentedBuffer:
  1346. assert self._indent == other._indent
  1347. res = IndentedBuffer(initial_indent=self._indent)
  1348. # TODO(rec): or should this be self.__class__(initial_indent=self._indent)?
  1349. res.writelines(self._lines)
  1350. res.writelines(other._lines)
  1351. return res
  1352. def contains(self, new_line: Union[DeferredLineBase, LineContext, str]) -> bool:
  1353. return new_line in self._lines
  1354. class FakeIndentedBuffer(IndentedBuffer):
  1355. def __init__(self) -> None:
  1356. super().__init__()
  1357. def __getattribute__(self, name: str) -> Any:
  1358. if name == "__class__": # Allow access to the class attribute
  1359. return object.__getattribute__(self, name)
  1360. raise RuntimeError(
  1361. f"Tried to call self.{name} on FakeIndentedBuffer. This buffer"
  1362. "is currently used on TritonTemplateKernel to prevent actual"
  1363. "writes to the body without explicitly specifying the body with"
  1364. "`TritonTemplateKernel.set_subgraph_body(name)`"
  1365. )
  1366. @contextlib.contextmanager
  1367. def restore_stdout_stderr() -> Iterator[None]:
  1368. initial_stdout, initial_stderr = sys.stdout, sys.stderr
  1369. try:
  1370. yield
  1371. finally:
  1372. sys.stdout, sys.stderr = initial_stdout, initial_stderr
  1373. class DeferredLineBase:
  1374. """A line that can be 'unwritten' at a later time"""
  1375. def __init__(self, line: str):
  1376. if not line.strip():
  1377. line = ""
  1378. self.line = line
  1379. def __call__(self) -> Union[str, None]:
  1380. """Returns either self.line or None to indicate the line has been 'unwritten'"""
  1381. raise NotImplementedError
  1382. def _new_line(self, line: str) -> Self:
  1383. """Returns a new deferred line with the same condition"""
  1384. raise NotImplementedError
  1385. def with_prefix(self, prefix: str) -> Self:
  1386. return self._new_line(f"{prefix}{self.line}")
  1387. def lstrip(self) -> Self:
  1388. return self._new_line(self.line.lstrip())
  1389. def __getitem__(self, index: Union[int, slice]) -> Self:
  1390. return self._new_line(self.line[index])
  1391. def __bool__(self) -> bool:
  1392. return bool(self.line)
  1393. def __len__(self) -> int:
  1394. return len(self.line)
  1395. class DelayReplaceLine(DeferredLineBase):
  1396. """At end of codegen call `line.replace(key, value_fn())`"""
  1397. def __init__(self, key: str, value_fn: Callable[[], str], line: str):
  1398. super().__init__(line)
  1399. self.key = key
  1400. self.value_fn = value_fn
  1401. def __call__(self) -> str:
  1402. return self.line.replace(self.key, self.value_fn())
  1403. def _new_line(self, line: str) -> DelayReplaceLine:
  1404. return DelayReplaceLine(self.key, self.value_fn, line)
  1405. @functools.cache
  1406. def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:
  1407. if isinstance(index_or_device, torch.device):
  1408. device = index_or_device
  1409. else:
  1410. device = torch.device(get_gpu_type(), index_or_device)
  1411. prop = DeviceProperties.create(device)
  1412. # SM logic is not relevant to ROCm gpus
  1413. # Arbitrarily skipping the older models
  1414. if torch.version.hip:
  1415. assert prop.major is not None
  1416. if prop.major < 9 or prop.major == 10:
  1417. log.warning("GPU arch does not support max_autotune_gemm mode usage")
  1418. return False
  1419. return True
  1420. min_sms = 16 if device.type == "xpu" else 68 # 3080
  1421. avail_sms = prop.multi_processor_count
  1422. if avail_sms < min_sms:
  1423. log.warning(
  1424. "Not enough SMs to use max_autotune_gemm mode",
  1425. extra={"min_sms": min_sms, "avail_sms": avail_sms},
  1426. )
  1427. return False
  1428. return True
  1429. @functools.lru_cache
  1430. def get_max_num_sms() -> int:
  1431. if torch.xpu.is_available():
  1432. return torch.xpu.get_device_properties().gpu_subslice_count
  1433. return torch.cuda.get_device_properties("cuda").multi_processor_count
  1434. @functools.lru_cache
  1435. def using_b200() -> bool:
  1436. """Returns true if the device is a NVIDIA B200, otherwise returns false."""
  1437. if not torch.cuda.is_available():
  1438. return False
  1439. # compute capability 10.0 or 10.0a is NVIDIA B200
  1440. device_properties = torch.cuda.get_device_properties(torch.cuda.current_device())
  1441. return device_properties.major == 10
  1442. def get_num_sms() -> int:
  1443. """Handle experimental carveout if set otherwise return hardware SM count"""
  1444. # TODO we need to properly guard on this global
  1445. if torch.xpu.is_available():
  1446. return get_max_num_sms()
  1447. carveout = torch._C._get_sm_carveout_experimental()
  1448. return get_max_num_sms() - (carveout if carveout is not None else 0)
  1449. def get_tma_workspace_arg(
  1450. num_tma_descriptors: int,
  1451. device: torch.device,
  1452. num_programs: Optional[int] = None,
  1453. ) -> WorkspaceArg:
  1454. """Builds and returns a WorkspaceArg for the device side TMA workspace buffer."""
  1455. from .codegen.common import WorkspaceArg, WorkspaceZeroMode
  1456. if num_programs is None:
  1457. num_programs = get_num_sms()
  1458. zero_mode = WorkspaceZeroMode.from_bool(False)
  1459. size = num_programs * num_tma_descriptors * TMA_DESCRIPTOR_SIZE
  1460. return WorkspaceArg(
  1461. count=size,
  1462. zero_mode=zero_mode,
  1463. device=device,
  1464. outer_name=WorkspaceArg.unique_name(),
  1465. )
  1466. def _use_template_for_gpu(
  1467. layout: Layout, allowed_layout_dtypes: list[torch.dtype]
  1468. ) -> bool:
  1469. if layout.dtype not in allowed_layout_dtypes:
  1470. log.debug(
  1471. "Not using template since dtype %s is not in allowed layout dtypes %s",
  1472. layout.dtype,
  1473. allowed_layout_dtypes,
  1474. )
  1475. return (
  1476. is_gpu(layout.device.type)
  1477. and layout.dtype in allowed_layout_dtypes
  1478. and is_big_gpu(layout.device)
  1479. )
  1480. def _use_autotune_backend(backend: str) -> bool:
  1481. return backend.upper() in [
  1482. x.strip() for x in config.max_autotune_gemm_backends.upper().split(",")
  1483. ]
  1484. def _use_conv_autotune_backend(backend: str) -> bool:
  1485. return backend.upper() in [
  1486. x.strip() for x in config.max_autotune_conv_backends.upper().split(",")
  1487. ]
  1488. def use_triton_template(
  1489. layout: Layout,
  1490. *,
  1491. enable_int32: bool = False,
  1492. enable_float8: bool = False,
  1493. check_max_autotune: bool = True,
  1494. ) -> bool:
  1495. from .codegen.common import BackendFeature, has_backend_feature
  1496. layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
  1497. if enable_int32:
  1498. layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32]
  1499. if enable_float8:
  1500. layout_dtypes.extend([torch.float8_e4m3fn, torch.float8_e5m2])
  1501. return (
  1502. (
  1503. (
  1504. is_gpu(layout.device.type)
  1505. and _use_template_for_gpu(layout, layout_dtypes)
  1506. )
  1507. or (layout.device.type == "cpu" and layout.dtype in layout_dtypes)
  1508. )
  1509. # some callers handle max-autotune checking externally
  1510. and (config.max_autotune or config.max_autotune_gemm or not check_max_autotune)
  1511. and _use_autotune_backend("TRITON")
  1512. and has_backend_feature(layout.device, BackendFeature.TRITON_TEMPLATES)
  1513. )
  1514. def can_use_tma(
  1515. *matrices: IRNode, output_layout: Optional[Layout] = None, add_guards: bool = False
  1516. ) -> bool:
  1517. """
  1518. Return True iff *all* supplied tensors satisfy the CUDA TMA constraints
  1519. that Triton relies on today.
  1520. * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
  1521. A tensor is accepted when:
  1522. * 1 ≤ rank ≤ 5 (cuTensorMapEncodeTiled)
  1523. * dtype in _TMA_SUPPORTED_DTYPES (CUtensorMapDataType enum)
  1524. * Base pointer 16-byte aligned
  1525. * Exactly one contiguous ("inner") dim with stride 1
  1526. * All "outer" dims have 16-byte aligned strides
  1527. * Inner dim size × itemsize is a multiple of 16
  1528. * For 1-byte dtypes (e.g. FP8), inner dim ≥ 32
  1529. """
  1530. from torch.utils._triton import has_triton_tma_device
  1531. from .virtualized import V
  1532. def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
  1533. return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)
  1534. def _is_tma_compatible_layout(layout: Optional[Layout]) -> bool:
  1535. if layout is None:
  1536. return True
  1537. sizes = layout.size
  1538. strides = layout.stride
  1539. dtype = layout.dtype
  1540. # Verify the output is 16-byte aligned
  1541. if not _aligned(layout.offset):
  1542. return False
  1543. return _is_tma_compatible(sizes, strides, dtype)
  1544. def _is_tma_compatible_matrix(m: IRNode) -> bool:
  1545. sizes = m.get_size()
  1546. strides = m.get_stride()
  1547. dtype = m.get_dtype()
  1548. # Base pointer 16-byte aligned
  1549. if m.get_name() in V.graph.unaligned_buffers:
  1550. return False
  1551. if (m_device := m.get_device()) is not None and m_device.type == "xpu":
  1552. return _is_tma_compatible_xpu(sizes, strides, dtype)
  1553. return _is_tma_compatible(sizes, strides, dtype)
  1554. def _is_tma_compatible(
  1555. sizes: Sequence[sympy.Expr],
  1556. strides: Sequence[_IntLike],
  1557. dtype: torch.dtype,
  1558. ) -> bool:
  1559. rank = len(sizes)
  1560. itemsize = dtype.itemsize
  1561. if rank < 1 or rank > 5:
  1562. return False
  1563. if dtype not in _TMA_SUPPORTED_DTYPES:
  1564. return False
  1565. if add_guards:
  1566. sizes_i = V.graph.sizevars.guard_int_seq(sizes)
  1567. strides_i = V.graph.sizevars.guard_int_seq(strides)
  1568. else:
  1569. sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes]
  1570. strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]
  1571. # Find the single contiguous ("inner") dim
  1572. inner = [
  1573. i
  1574. for i, st in enumerate(strides_i)
  1575. if V.graph.sizevars.statically_known_equals(st, 1)
  1576. ]
  1577. if len(inner) != 1:
  1578. return False
  1579. inner_idx = inner[0]
  1580. # All "outer" dims must have 16-byte aligned strides
  1581. for i, st in enumerate(strides_i):
  1582. if i == inner_idx:
  1583. continue
  1584. if not _aligned(st * itemsize):
  1585. return False
  1586. # Inner dim byte width must be a multiple of 16 B
  1587. inner_dim = sizes_i[inner_idx]
  1588. if not _aligned(inner_dim * itemsize):
  1589. return False
  1590. # 1-byte dtypes (FP8 etc.) need inner dim ≥ 32 for tensor core alignment
  1591. if itemsize == 1 and not V.graph.sizevars.statically_known_geq(inner_dim, 32):
  1592. return False
  1593. return True
  1594. def _is_tma_compatible_xpu(
  1595. sizes: Sequence[sympy.Expr],
  1596. strides: Sequence[_IntLike],
  1597. dtype: torch.dtype,
  1598. ) -> bool:
  1599. # Make sure the last dimension is contiguous
  1600. last_stride = strides[-1]
  1601. last_stride_hint = V.graph.sizevars.symbolic_hint(last_stride)
  1602. if not V.graph.sizevars.statically_known_equals(last_stride_hint, 1):
  1603. return False
  1604. # Triton's type of index is uint32, so all dimensions must fit in uint32
  1605. MAX_UINT32 = 2**32 - 1
  1606. for size in sizes:
  1607. size_hint = V.graph.sizevars.symbolic_hint(size)
  1608. if V.graph.sizevars.statically_known_gt(size_hint, MAX_UINT32):
  1609. return False
  1610. return True
  1611. return (
  1612. has_triton_tma_device()
  1613. and all(_is_tma_compatible_matrix(m) for m in matrices)
  1614. and _is_tma_compatible_layout(output_layout)
  1615. )
  1616. def use_triton_tma_template(
  1617. *matrices: IRNode, output_layout: Layout, add_guards: bool = False
  1618. ) -> bool:
  1619. layout = output_layout if config.triton.enable_template_tma_store else None
  1620. return (
  1621. all(len(m.get_size()) == 2 for m in matrices)
  1622. and can_use_tma(*matrices, output_layout=layout, add_guards=add_guards)
  1623. and config.triton.enable_persistent_tma_matmul
  1624. )
  1625. def use_triton_blackwell_tma_template(
  1626. *matrices: IRNode, output_layout: Layout, add_guards: bool = False
  1627. ) -> bool:
  1628. if not use_triton_tma_template(
  1629. *matrices, output_layout=output_layout, add_guards=add_guards
  1630. ):
  1631. return False
  1632. from torch.utils._triton import has_triton_tensor_descriptor_host_tma
  1633. from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
  1634. # Blackwell template require the tensor descriptor API, not the experimental API.
  1635. return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
  1636. def use_triton_scaling_template(
  1637. scale_option_a: ScalingType,
  1638. scale_option_b: ScalingType,
  1639. scaling_types: list[ScalingType],
  1640. ) -> bool:
  1641. return scale_option_a in scaling_types and scale_option_b in scaling_types
  1642. @functools.lru_cache(maxsize=1)
  1643. def ensure_cute_available() -> bool:
  1644. """Check if CuTeDSL is importable; cache the result for reuse.
  1645. Call ensure_cute_available.cache_clear() after installing CuTeDSL
  1646. in the same interpreter to retry the import.
  1647. """
  1648. try:
  1649. return importlib.util.find_spec("cutlass") is not None
  1650. except ImportError:
  1651. return False
  1652. @functools.lru_cache(maxsize=1)
  1653. def ensure_nv_universal_gemm_available() -> bool:
  1654. """Check if NVIDIA Universal GEMM (cutlass_api) is importable; cache the result for reuse.
  1655. Call ensure_nv_universal_gemm_available.cache_clear() after installing cutlass_api
  1656. in the same interpreter to retry the import.
  1657. """
  1658. try:
  1659. return importlib.util.find_spec("cutlass_api") is not None
  1660. except ImportError:
  1661. return False
  1662. @functools.lru_cache(maxsize=1)
  1663. def ensure_nvmatmul_heuristics_available() -> bool:
  1664. """Check if nvMatmulHeuristics is importable; cache the result for reuse.
  1665. nvMatmulHeuristics provides performance model-based kernel selection
  1666. for NVIDIA GEMM operations.
  1667. Call ensure_nvmatmul_heuristics_available.cache_clear() after installing
  1668. nvMatmulHeuristics in the same interpreter to retry the import.
  1669. """
  1670. try:
  1671. return importlib.util.find_spec("nvMatmulHeuristics") is not None
  1672. except ImportError:
  1673. return False
  1674. def use_blackwell_cutedsl_grouped_mm(
  1675. mat_a: Any,
  1676. mat_b: Any,
  1677. layout: Layout,
  1678. a_is_2d: bool,
  1679. b_is_2d: bool,
  1680. offs: Optional[Any],
  1681. bias: Optional[Any],
  1682. scale_result: Optional[Any],
  1683. ) -> bool:
  1684. """
  1685. Returns True if we can use the blackwell kernel for grouped mm.
  1686. Required conditions:
  1687. 1. CuTeDSL backend is enabled
  1688. 2. CuTeDSL is available
  1689. 3. We are on a blackwell arch
  1690. 4. The dtype is bf16
  1691. 5. Max autotune or max autotune gemm is enabled
  1692. 6. A, B, and the output are 16B aligned
  1693. 7. We are not using dynamic shapes
  1694. 8. A is 2d
  1695. 9. B is 3d
  1696. 10. Offsets are provided
  1697. 11. Bias and Scale are not provided
  1698. """
  1699. if not ensure_cute_available():
  1700. return False
  1701. if not _use_autotune_backend("CUTEDSL"):
  1702. return False
  1703. from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch
  1704. if not is_gpu(layout.device.type):
  1705. return False
  1706. if not is_datacenter_blackwell_arch():
  1707. return False
  1708. layout_dtypes = [torch.bfloat16]
  1709. if not _use_template_for_gpu(layout, layout_dtypes):
  1710. return False
  1711. if not (config.max_autotune or config.max_autotune_gemm):
  1712. return False
  1713. # Checks for 16B ptr and stride alignment
  1714. if not can_use_tma(mat_a, mat_b, output_layout=layout):
  1715. return False
  1716. if any(is_dynamic(x) for x in [mat_a, mat_b]):
  1717. return False
  1718. if not a_is_2d or b_is_2d:
  1719. return False
  1720. if offs is None:
  1721. return False
  1722. if bias is not None or scale_result is not None:
  1723. return False
  1724. return True
  1725. def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
  1726. from .virtualized import V
  1727. gemm_size = V.graph.sizevars.optimization_hint(m * n * k, fallback=-1)
  1728. if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size:
  1729. return False
  1730. from .codegen.cutlass.utils import try_import_cutlass
  1731. # Do not use cutlass template on ROCm
  1732. if torch.version.hip:
  1733. return False
  1734. # output dtype
  1735. # FP32 not supported: https://github.com/pytorch/pytorch/issues/145952
  1736. layout_dtypes = [torch.float16, torch.bfloat16, torch.int32]
  1737. res = (
  1738. _use_template_for_gpu(layout, layout_dtypes)
  1739. and (config.max_autotune or config.max_autotune_gemm)
  1740. and _use_autotune_backend("CUTLASS")
  1741. )
  1742. if res:
  1743. if not try_import_cutlass():
  1744. log.warning(
  1745. "Failed to import CUTLASS lib. Please check whether "
  1746. "_inductor.config.cutlass.cutlass_dir %s is set correctly. "
  1747. "Skipping CUTLASS backend for now.",
  1748. config.cutlass.cutlass_dir,
  1749. )
  1750. return False
  1751. return res
  1752. def use_nv_universal_gemm_template(
  1753. layout: Layout,
  1754. m: _IntLike,
  1755. n: _IntLike,
  1756. k: _IntLike,
  1757. mat_a: IRNode,
  1758. mat_b: IRNode,
  1759. offs: Optional[IRNode] = None,
  1760. g: Optional[_IntLike] = None,
  1761. ) -> bool:
  1762. """
  1763. Return True if we can use the NVIDIA Universal GEMM Template.
  1764. Required conditions:
  1765. 1. NVGEMM backend is enabled
  1766. 2. cutlass_api is available
  1767. 3. We are on a NVIDIA GPU
  1768. 4. Max autotune or max autotune gemm is enabled
  1769. 5. Not in AOT Inductor mode (requires runtime JIT compilation)
  1770. 6. Base pointers are 16-byte aligned
  1771. 7. Shape dimensions are not unbacked symbols
  1772. Note:
  1773. - Shape and stride constraints are handled internally by
  1774. cutlass_api.get_kernels() which filters incompatible kernels.
  1775. - GroupedGemm currently only supports TN layout (column-major B).
  1776. Any other layout will act as a noop and fall back to ATen.
  1777. - Dynamic shapes are supported as long as they have hints
  1778. (from example inputs).
  1779. """
  1780. from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
  1781. if not ensure_cute_available():
  1782. return False
  1783. if not ensure_nv_universal_gemm_available():
  1784. return False
  1785. if not _use_autotune_backend("NVGEMM"):
  1786. return False
  1787. from .virtualized import V
  1788. if V.aot_compilation:
  1789. return False
  1790. if layout.device.type != "cuda" or torch.version.hip:
  1791. return False
  1792. if not (config.max_autotune or config.max_autotune_gemm):
  1793. return False
  1794. # cutlass_api can't handle unbacked symbols because it needs to evaluate
  1795. # shape constraints (e.g., stride divisibility by 8, N/K divisibility by 16).
  1796. # Unbacked symbols have no hint values, causing GuardOnDataDependentSymNode errors.
  1797. dims_to_check = [m, n, k]
  1798. if g is not None:
  1799. dims_to_check.append(g)
  1800. if any(has_free_unbacked_symbols(dim) for dim in dims_to_check):
  1801. return False
  1802. # Base pointer must be 16-byte aligned. cutlass_api can't check this at
  1803. # compile time because it only sees FakeTensors without real data pointers.
  1804. tensors_to_check = [mat_a, mat_b]
  1805. if offs is not None:
  1806. tensors_to_check.append(offs)
  1807. if any(t.get_name() in V.graph.unaligned_buffers for t in tensors_to_check):
  1808. return False
  1809. return True
  1810. def _use_cutlass_for_op(op_name: str) -> bool:
  1811. """Check if CUTLASS should be used for the given operation."""
  1812. enabled_ops = config.cutlass.cutlass_enabled_ops.upper()
  1813. if enabled_ops == "ALL":
  1814. return True
  1815. return op_name.upper() in [x.strip() for x in enabled_ops.split(",")]
  1816. _IntLike: TypeAlias = Union[int, sympy.Expr]
  1817. @functools.cache
  1818. def use_decompose_k_choice(
  1819. m: _IntLike, n: _IntLike, k: _IntLike, threshold_multiple: int = 1
  1820. ) -> bool:
  1821. from torch._inductor.virtualized import V
  1822. decompose_k_threshold = config.triton.decompose_k_threshold * threshold_multiple
  1823. return (
  1824. V.graph.sizevars.statically_known_true(
  1825. sympy.And(
  1826. sympy.Ge(k, decompose_k_threshold * m),
  1827. sympy.Ge(k, decompose_k_threshold * n),
  1828. )
  1829. )
  1830. and not V.graph.aot_mode # TODO: Support AOTI for decomposeK
  1831. and not V.graph.cpp_wrapper
  1832. and config.triton.num_decompose_k_splits > 0
  1833. )
  1834. @functools.cache
  1835. def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
  1836. """
  1837. Check if we should use the contiguous subgraph transform.
  1838. This transform makes the second matrix contiguous before the matmul.
  1839. """
  1840. contiguous_threshold = config.rocm.contiguous_threshold
  1841. # Similar conditions to decompose_k but for contiguous transform
  1842. from torch._inductor.virtualized import V
  1843. return (
  1844. bool(torch.version.hip) # Only relevant on AMD
  1845. and V.graph.sizevars.statically_known_true(
  1846. sympy.And(
  1847. sympy.Ge(k, contiguous_threshold * m),
  1848. sympy.Ge(k, contiguous_threshold * n),
  1849. )
  1850. )
  1851. and not V.graph.aot_mode
  1852. and not V.graph.cpp_wrapper
  1853. )
  1854. @functools.cache
  1855. def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
  1856. # To limit compile time
  1857. k_splits_limit = config.triton.num_decompose_k_splits
  1858. # Hand-tuned
  1859. default_k_splits = [16, 32, 64, 128, 256]
  1860. # If k is a sympy expression, we can't do any splitting
  1861. if isinstance(k, sympy.Expr) and not k.is_number:
  1862. return default_k_splits
  1863. elif k_splits_limit == 0:
  1864. return []
  1865. if (isinstance(m, sympy.Expr) and not m.is_number) or (
  1866. isinstance(n, sympy.Expr) and not n.is_number
  1867. ):
  1868. max_k_split = 256
  1869. else:
  1870. max_k_split = min(k // m, k // n)
  1871. min_k_split = 2
  1872. # Get all divisors of k, k has to be divisible by kPart
  1873. divisors = sympy.divisors(k)
  1874. divisors = [
  1875. divisor
  1876. for divisor in divisors
  1877. if divisor <= max_k_split and divisor >= min_k_split
  1878. ]
  1879. pow_of_2_divisors, mul_of_32_divisors, rest_of_splits = [], [], []
  1880. for d in divisors:
  1881. kPart = k // d
  1882. # Smaller than 128 might not even fit in a single tile, BLOCK_K can be 128
  1883. if kPart < 128:
  1884. continue
  1885. # Power of 2 divisors are best performing, conform to hardware
  1886. if (kPart & kPart - 1) == 0 and kPart >= 128:
  1887. pow_of_2_divisors.append(d)
  1888. # Else check if creates a multiple of 32
  1889. elif kPart % 32 == 0:
  1890. mul_of_32_divisors.append(d)
  1891. # otherwise, take the smallest values
  1892. else:
  1893. rest_of_splits.append(d)
  1894. if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
  1895. return pow_of_2_divisors + mul_of_32_divisors + rest_of_splits
  1896. best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits
  1897. # Otherwise, conform results to k_splits_limit
  1898. return best_splits[:k_splits_limit]
  1899. @functools.cache
  1900. def _rocm_native_device_arch_name(device: str) -> str:
  1901. return torch.cuda.get_device_properties(device).gcnArchName
  1902. @functools.cache
  1903. def try_import_ck_lib() -> tuple[
  1904. Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
  1905. ]:
  1906. try:
  1907. import ck4inductor # type: ignore[import]
  1908. from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import]
  1909. gen_ops_library,
  1910. gen_ops_preselected,
  1911. )
  1912. from ck4inductor.universal_gemm.op import ( # type: ignore[import]
  1913. CKGemmOperation,
  1914. )
  1915. package_dirname = os.path.dirname(ck4inductor.__file__)
  1916. except ImportError:
  1917. def gen_ops_library() -> list[Any]:
  1918. return []
  1919. def gen_ops_preselected() -> list[Any]:
  1920. return []
  1921. class CKGemmOperation: # type: ignore[no-redef]
  1922. pass
  1923. package_dirname = None
  1924. return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation
  1925. def use_ck_template(layout: Layout) -> bool:
  1926. # config knobs check 1
  1927. if not (config.max_autotune or config.max_autotune_gemm):
  1928. return False
  1929. # platform check
  1930. if not torch.version.hip:
  1931. return False
  1932. # tensors must be on GPU
  1933. if layout.device.type != "cuda":
  1934. return False
  1935. # hardware check
  1936. # if config arch list is not specified, get the native arch from the device properties
  1937. native_arch = _rocm_native_device_arch_name(layout.device)
  1938. requested_archs = {k.split(":")[0]: k for k in config.rocm.arch} or {
  1939. native_arch.split(":")[0]: native_arch
  1940. }
  1941. requested_supported_archs = [
  1942. requested_archs[k]
  1943. for k in requested_archs.keys() & config.rocm.ck_supported_arch
  1944. ]
  1945. if not requested_supported_archs:
  1946. return False
  1947. # supported input dtypes
  1948. if layout.dtype not in [torch.float16, torch.bfloat16, torch.float32]:
  1949. return False
  1950. ck_package_dirname, _, _, _ = try_import_ck_lib()
  1951. if not ck_package_dirname:
  1952. log.warning("Please pip install Composable Kernel package")
  1953. return False
  1954. config.rocm.ck_dir = ck_package_dirname
  1955. return True
  1956. def use_ck_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool:
  1957. from .virtualized import V
  1958. return (
  1959. _use_autotune_backend("CK")
  1960. and use_ck_template(layout)
  1961. and V.graph.sizevars.optimization_hint(m * n * k, fallback=-1) > 0
  1962. )
  1963. def use_ck_tile_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool:
  1964. from .virtualized import V
  1965. return (
  1966. _use_autotune_backend("CKTILE")
  1967. and use_ck_template(layout)
  1968. and V.graph.sizevars.optimization_hint(m * n * k, fallback=-1) > 0
  1969. )
  1970. def use_ck_conv_template(layout: Layout) -> bool:
  1971. return _use_conv_autotune_backend("CK") and use_ck_template(layout)
  1972. def _use_template_for_cpu(layout: Layout) -> bool:
  1973. return (
  1974. config.max_autotune or config.max_autotune_gemm
  1975. ) and layout.device.type == "cpu"
def use_cpp_bmm_template(
    layout: Layout, mat1: Union[ReinterpretView, Buffer], mat2: IRNode
) -> bool:
    """Return True if the C++ GEMM template can be used for a batched matmul.

    Delegates the general checks to ``use_cpp_gemm_template`` (with
    ``require_constant_mat2=False``) and additionally requires ``mat1`` to be
    contiguous as a whole, or to be a 3D float32 tensor whose per-batch 2D
    matrices are contiguous.
    """
    from .ir import Layout

    assert isinstance(mat1.layout, Layout)

    # In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous.
    # But the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune.
    # So here we specifically check for contiguity within the 2D matrix of each batch.
    mat1_size = mat1.layout.size
    mat1_stride = mat1.layout.stride
    # Per-batch row-major check: the row stride equals the column count and the
    # column stride is 1. Size/stride are read here, before
    # use_cpp_gemm_template (which may call freeze_layout) runs below; keep
    # this statement order.
    mat1_each_batch_is_contiguous = (
        _use_template_for_cpu(layout)
        and mat1.get_dtype() == torch.float32
        and (len(mat1_size) == 3)
        and (len(mat1_stride) == 3)
        and (mat1_stride[1] == mat1_size[2])
        and (mat1_stride[2] == 1)
    )
    return use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False) and (
        mat1.layout.is_contiguous() or mat1_each_batch_is_contiguous
    )
def use_cpp_gemm_template(
    layout: Layout,
    mat1: IRNode,
    mat2: IRNode,
    mat2_transposed: bool = False,
    require_constant_mat2: bool = True,
    is_woq_int4: bool = False,
    q_group_size: Optional[int] = None,
) -> bool:
    """Return True if the C++ GEMM template can be used for ``mat1 @ mat2``.

    All of the following must hold:
    - CPU max-autotune is enabled with the "CPP" backend.
    - ``config.cpp.weight_prepack`` is enabled.
    - ``n`` and ``k`` are static.
    - A micro-GEMM was created for the dtypes.
    - The output dtype is supported.
    - ``mat1`` has unit stride in its last dimension.
    - ``mat2`` (after unwrapping views) is a ``StorageBox`` that is a module
      buffer, unless ``require_constant_mat2`` is False.

    ``mat2_transposed`` and ``is_woq_int4`` are forwarded to ``mm_args``.
    ``is_woq_int4`` (presumably weight-only int4 quantization; confirm) also
    disables ``use_ref`` for the micro-GEMM. ``q_group_size`` is forwarded to
    ``create_micro_gemm``.
    """
    from . import ir
    from .codegen.cpp_micro_gemm import create_micro_gemm
    from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype
    from .kernel.mm_common import mm_args

    # Only considered under CPU max-autotune with the "CPP" backend enabled.
    if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"):
        return False

    if not config.cpp.weight_prepack:
        return False

    int8_gemm = mat1.get_dtype() in [torch.uint8, torch.int8]
    # Dtypes accepted for the (recomputed) output layout in the final check.
    layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8, torch.int8]
    # NOTE: rebinds layout, mat1 and mat2 to the values produced by mm_args;
    # all checks below use the rebound values. For int8 inputs the incoming
    # layout's dtype is forwarded as out_dtype.
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1,
        mat2,
        out_dtype=layout.dtype if int8_gemm else None,
        mat2_transposed=mat2_transposed,
        use_4x2_dim=is_woq_int4,
    )

    # TODO(jgong5): support dynamic shapes for n or k
    if has_free_symbols((n, k)):
        return False

    # Look through views so the StorageBox / module-buffer checks below see the
    # underlying node.
    if isinstance(mat2, ir.BaseView):
        mat2 = mat2.unwrap_view()

    output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype())
    # The result is only checked against None below; presumably None means no
    # micro-kernel fits this configuration (confirm in create_micro_gemm).
    micro_gemm = create_micro_gemm(
        "micro_gemm",
        m,
        n,
        k,
        input_dtype=mat1.get_dtype(),
        input2_dtype=mat2.get_dtype(),
        output_dtype=output_dtype,
        num_threads=parallel_num_threads(),
        use_ref=not is_woq_int4,
        q_group_size=q_group_size,
    )

    def is_last_dim_stride1(x: IRNode) -> bool:
        # NOTE: freezes x's layout as a side effect before reading its stride.
        x.freeze_layout()
        return x.get_stride()[-1] == 1

    return (
        layout.dtype in layout_dtypes
        and micro_gemm is not None
        and is_last_dim_stride1(mat1)  # TODO(jgong5): support transposed input
        and isinstance(mat2, ir.StorageBox)
        and (mat2.is_module_buffer() or not require_constant_mat2)
    )
  2051. def use_aten_gemm_kernels() -> bool:
  2052. return not (
  2053. config.max_autotune or config.max_autotune_gemm
  2054. ) or _use_autotune_backend("ATEN")
  2055. class DebugDirManager:
  2056. counter = itertools.count(0)
  2057. prev_debug_name: str
  2058. def __init__(self) -> None:
  2059. self.id = next(DebugDirManager.counter)
  2060. def __enter__(self) -> None:
  2061. self.prev_debug_name = torch._dynamo.config.debug_dir_root
  2062. self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
  2063. torch._dynamo.config.debug_dir_root = self.new_name
  2064. def __exit__(self, *args: Any) -> None:
  2065. shutil.rmtree(self.new_name)
  2066. torch._dynamo.config.debug_dir_root = self.prev_debug_name
  2067. def run_and_get_code(
  2068. fn: Callable[P, _T],
  2069. *args: P.args,
  2070. **kwargs: P.kwargs,
  2071. ) -> tuple[_T, list[str]]:
  2072. from .graph import GraphLowering
  2073. source_codes: OrderedSet[str] = OrderedSet()
  2074. def save_output_code(code: str) -> None:
  2075. source_codes.add(code)
  2076. with mock.patch.object(GraphLowering, "save_output_code", save_output_code):
  2077. torch._dynamo.reset()
  2078. result = fn(*args, **kwargs)
  2079. return result, list(source_codes)
  2080. def run_and_get_kernels(
  2081. fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
  2082. ) -> tuple[_T, list[str]]:
  2083. remove_quote = kwargs.pop("remove_quote", False)
  2084. # pyrefly: ignore [bad-argument-type]
  2085. result, source_codes = run_and_get_code(fn, *args, **kwargs)
  2086. kernels = []
  2087. for code in source_codes:
  2088. kernels.extend(re.findall(r"'''.*?'''", code, re.DOTALL))
  2089. if remove_quote:
  2090. kernels = [kernel[3:-3] for kernel in kernels]
  2091. return result, kernels
  2092. def run_fw_bw_and_get_code(fn: Callable[..., Any]) -> tuple[Any, list[str]]:
  2093. def run_with_backward() -> Any:
  2094. result = fn()
  2095. result.sum().backward()
  2096. return result
  2097. return run_and_get_code(run_with_backward)
  2098. def get_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> list[str]:
  2099. """Get the inductor-generated code, but skip any actual compilation or running."""
  2100. from .graph import GraphLowering
  2101. source_codes: list[str] = []
  2102. def save_output_code(code: str) -> None:
  2103. source_codes.append(code)
  2104. def patched_compile_to_module(self: GraphLowering) -> Any:
  2105. class DummyModule:
  2106. """This is empty to replace the generated triton module"""
  2107. def __init__(self) -> None:
  2108. pass
  2109. def call(self, *args: Any, **kwargs: Any) -> None:
  2110. # Don't do anything when called
  2111. pass
  2112. wrapper_code, kernel_code = (
  2113. self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen()
  2114. )
  2115. # Skip all the actual compiling.
  2116. save_output_code(wrapper_code.value)
  2117. if kernel_code:
  2118. save_output_code(kernel_code.value)
  2119. return DummyModule()
  2120. with (
  2121. mock.patch.object(
  2122. GraphLowering, "compile_to_module", patched_compile_to_module
  2123. ),
  2124. mock.patch.object(GraphLowering, "save_output_code", save_output_code),
  2125. ):
  2126. torch._dynamo.reset()
  2127. # Note the return here is None
  2128. _ = fn(*args, **kwargs)
  2129. return source_codes
  2130. def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> str:
  2131. # pyrefly: ignore [bad-argument-type]
  2132. source_codes = get_code(fn, *args, **kwargs)
  2133. # Can have two outputs if backwards was eagerly compiled
  2134. assert 1 <= len(source_codes) <= 2, (
  2135. f"expected one or two code outputs got {len(source_codes)}"
  2136. )
  2137. return source_codes[0]
  2138. def run_and_get_triton_code(
  2139. fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
  2140. ) -> str:
  2141. # pyrefly: ignore [bad-argument-type]
  2142. _, source_codes = run_and_get_code(fn, *args, **kwargs)
  2143. # Can have two outputs if backwards was eagerly compiled
  2144. assert 1 <= len(source_codes) <= 2, (
  2145. f"expected one or two code outputs got {len(source_codes)}"
  2146. )
  2147. return source_codes[0]
  2148. def run_and_get_graph_lowering(
  2149. fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
  2150. ) -> tuple[Any, list[GraphLowering]]:
  2151. from torch._inductor.graph import GraphLowering
  2152. from torch._inductor.output_code import CompiledFxGraph
  2153. real_init = CompiledFxGraph.__init__
  2154. graph_lowerings = []
  2155. def fake_init(*args: Any, **kwargs: Any) -> None:
  2156. real_init(*args, **kwargs)
  2157. graph = args[2]
  2158. assert isinstance(graph, GraphLowering)
  2159. graph_lowerings.append(graph)
  2160. with mock.patch.object(CompiledFxGraph, "__init__", fake_init):
  2161. result = fn(*args, **kwargs)
  2162. return result, graph_lowerings
  2163. @contextlib.contextmanager
  2164. def override_lowering(
  2165. aten_op: Callable[..., Any], override_fn: Callable[..., Any]
  2166. ) -> Iterator[None]:
  2167. """
  2168. Override the lowering of aten_op with override_fn.
  2169. The first argument of override_fn is the original lowering fn.
  2170. """
  2171. from torch._inductor import lowering
  2172. orig_fn = lowering.lowerings[aten_op]
  2173. try:
  2174. lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn)
  2175. yield
  2176. finally:
  2177. lowering.lowerings[aten_op] = orig_fn
  2178. def add_scheduler_init_hook(
  2179. pre_fn: Callable[..., Any], post_fn: Optional[Callable[..., Any]] = None
  2180. ) -> Any:
  2181. """
  2182. Add hook functions to be called at the beginning and end of Scheduler.__init__.
  2183. Used for unit tests.
  2184. """
  2185. from torch._inductor.scheduler import Scheduler
  2186. orig_fn = Scheduler.__init__
  2187. def wrapper(scheduler: Any, nodes: Any) -> Any:
  2188. pre_fn(scheduler, nodes)
  2189. out = orig_fn(scheduler, nodes)
  2190. if post_fn:
  2191. post_fn(scheduler, nodes)
  2192. return out
  2193. return unittest.mock.patch.object(Scheduler, "__init__", wrapper)
  2194. def developer_warning(msg: str) -> None:
  2195. """
  2196. Warnings that will be actionable for PyTorch developers, but not
  2197. end users. Allows us to easily disable them in stable releases but
  2198. keep them on for nightly builds.
  2199. """
  2200. if config.developer_warnings:
  2201. log.warning(msg)
  2202. else:
  2203. log.info(msg)
  2204. def get_benchmark_name() -> Optional[str]:
  2205. """
  2206. An experimental API used only when config.benchmark_kernel is true.
  2207. The benchmark name is only available at codegen time. So we can not
  2208. directly call it in benchmark_all_kernels which is run after codegen.
  2209. The function assumes the argument after --only is the benchmark name.
  2210. It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc
  2211. scripts, this function may return None.
  2212. There are 2 flavors of --only argument we need handle:
  2213. 1. --only model_name
  2214. 2. --only=model_name
  2215. """
  2216. try:
  2217. idx = sys.argv.index("--only")
  2218. if (
  2219. idx + 1 < len(sys.argv)
  2220. and len(sys.argv[idx + 1]) > 0
  2221. and sys.argv[idx + 1][0] != "-"
  2222. ):
  2223. return sys.argv[idx + 1]
  2224. except ValueError:
  2225. pass
  2226. for arg in sys.argv:
  2227. if arg.startswith("--only="):
  2228. return arg[len("--only=") :]
  2229. return None
  2230. def is_ones(items: Sequence[Any]) -> bool:
  2231. return all(x == 1 for x in items)
  2232. def is_zeros(items: Sequence[Any]) -> bool:
  2233. return all(x == 0 for x in items)
  2234. def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool:
  2235. return all(
  2236. item.device == torch.device("cpu")
  2237. for item in inputs
  2238. if isinstance(item, torch.Tensor)
  2239. )
  2240. def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
  2241. assert isinstance(val, sympy.Expr), (
  2242. "only support sympy.Expr as input to get_sympy_Expr_dtype"
  2243. )
  2244. if val.is_integer: # type: ignore[attr-defined]
  2245. return torch.int64
  2246. else:
  2247. return torch.float64
  2248. @contextlib.contextmanager
  2249. def maybe_profile(should_profile: bool, *args: Any, **kwargs: Any) -> Iterator[Any]:
  2250. if should_profile:
  2251. with torch.profiler.profile(*args, **kwargs) as p:
  2252. yield p
  2253. else:
  2254. yield
  2255. def parallel_num_threads() -> int:
  2256. threads = config.cpp.threads
  2257. if threads < 1:
  2258. threads = torch.get_num_threads()
  2259. return threads
  2260. @functools.cache
  2261. def get_backend_num_stages() -> int:
  2262. from .runtime.triton_helpers import get_backend_options
  2263. options = get_backend_options()
  2264. return options.get("num_stages", 2 if torch.version.hip else 3)
  2265. @functools.cache
  2266. def get_device_tflops(dtype: torch.dtype) -> float:
  2267. """
  2268. We don't want to throw errors in this function. First check to see if the device is in device_info.py,
  2269. then fall back to the inaccurate triton estimation.
  2270. """
  2271. ds_tops = datasheet_tops(
  2272. dtype, is_tf32=torch.backends.cuda.matmul.fp32_precision == "tf32"
  2273. )
  2274. if ds_tops is not None:
  2275. return ds_tops
  2276. from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops
  2277. SM80OrLater = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (
  2278. 8,
  2279. 0,
  2280. )
  2281. assert dtype in (torch.float16, torch.bfloat16, torch.float32)
  2282. if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
  2283. # Triton API change in https://github.com/triton-lang/triton/pull/2293
  2284. from torch._utils_internal import max_clock_rate
  2285. sm_clock = max_clock_rate()
  2286. if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
  2287. return get_max_tensorcore_tflops(dtype, sm_clock)
  2288. if torch.backends.cuda.matmul.fp32_precision == "tf32":
  2289. return get_max_tensorcore_tflops(torch.float32, sm_clock)
  2290. else:
  2291. return get_max_simd_tflops(torch.float32, sm_clock)
  2292. else:
  2293. if dtype in (torch.float16, torch.bfloat16) and SM80OrLater:
  2294. return get_max_tensorcore_tflops(dtype)
  2295. if torch.backends.cuda.matmul.fp32_precision == "tf32":
  2296. return get_max_tensorcore_tflops(torch.float32)
  2297. else:
  2298. return get_max_simd_tflops(torch.float32)
  2299. @functools.cache
  2300. def get_gpu_dram_gbps() -> int:
  2301. from triton.testing import get_dram_gbps
  2302. return get_dram_gbps()
  2303. def get_gpu_shared_memory() -> int:
  2304. from triton.runtime import driver
  2305. return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0)
  2306. def get_max_numwarps() -> int:
  2307. if torch.cuda.is_available():
  2308. warp_size = torch.cuda.get_device_properties().warp_size
  2309. # pyrefly: ignore [missing-attribute]
  2310. max_threads_per_block = torch.cuda.get_device_properties().max_threads_per_block
  2311. else:
  2312. # Defaults
  2313. warp_size = 32
  2314. max_threads_per_block = 1024
  2315. return max_threads_per_block // warp_size
  2316. def is_welford_reduction(reduction_type: str) -> bool:
  2317. return reduction_type.startswith("welford")
  2318. def reduction_num_outputs(reduction_type: str) -> int:
  2319. if is_welford_reduction(reduction_type):
  2320. return 3
  2321. elif reduction_type == "online_softmax_reduce":
  2322. return 2
  2323. else:
  2324. return 1
  2325. def is_linux() -> bool:
  2326. return platform.system() == "Linux"
  2327. def is_windows() -> bool:
  2328. return sys.platform == "win32"
  2329. def has_free_symbols(itr: Iterable[Any]) -> bool:
  2330. return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr)
  2331. def is_dynamic(*args: Any) -> bool:
  2332. from . import ir
  2333. for t in args:
  2334. if isinstance(
  2335. t, (ir.TensorBox, ir.StorageBox, ir.BaseView, ir.ComputedBuffer, ir.Buffer)
  2336. ):
  2337. if has_free_symbols(t.maybe_get_size() or ()) or has_free_symbols(
  2338. t.maybe_get_stride() or ()
  2339. ):
  2340. return True
  2341. elif not isinstance(t, ir.IRNode):
  2342. continue
  2343. else:
  2344. raise TypeError(f"unexpected type for is_dynamic {type(t)}")
  2345. return False
  2346. # Placeholder strings used in triton codegen.
  2347. class Placeholder(enum.Enum):
  2348. # The placeholder for the actual name of a triton kernel.
  2349. # e.g. for "def triton_" it would be "triton_"
  2350. KERNEL_NAME = "KERNEL_NAME"
  2351. # The descriptive name of the triton kernel; when unique_kernel_names = False, this
  2352. # placeholder will be replaced with a string with more information.
  2353. DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
  2354. def pass_execution_and_save(
  2355. func: Callable[..., Any], gm: GraphModule, inp: Sequence[Any], msg: str
  2356. ) -> None:
  2357. from .pattern_matcher import stable_topological_sort
  2358. with tempfile.NamedTemporaryFile(
  2359. mode="w",
  2360. encoding="utf-8",
  2361. ) as f:
  2362. before_io = io.StringIO()
  2363. after_io = io.StringIO()
  2364. ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
  2365. print(f"Before:\n{gm.graph}", file=f)
  2366. print(gm.graph, file=before_io)
  2367. start_time = datetime.now()
  2368. with GraphTransformObserver(gm, msg):
  2369. func(gm.graph)
  2370. time_elapsed = datetime.now() - start_time
  2371. # recompile graph
  2372. stable_topological_sort(gm.graph)
  2373. gm.graph.lint()
  2374. gm.recompile()
  2375. print(f"After:\n{gm.graph}", file=f)
  2376. print(gm.graph, file=after_io)
  2377. t = before_io.getvalue() == after_io.getvalue()
  2378. log.info(
  2379. "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
  2380. msg,
  2381. f.name,
  2382. t,
  2383. time_elapsed,
  2384. )
  2385. def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) -> bool:
  2386. """
  2387. Check if input buffer is a multi-outputs template buffer
  2388. """
  2389. from . import ir
  2390. return isinstance(input_buf, ir.CppTemplateBuffer) and isinstance(
  2391. input_buf.layout, ir.MultiOutputLayout
  2392. )
  2393. def is_output_of_multi_outputs_template(
  2394. input_buf: Optional[Union[Buffer, Operation]],
  2395. ) -> bool:
  2396. """
  2397. Check if input buffer is a output of multi-outputs template buffer
  2398. """
  2399. from . import ir
  2400. return (
  2401. isinstance(input_buf, ir.MultiOutput)
  2402. and len(input_buf.inputs) == 1
  2403. and is_multi_outputs_template(input_buf.inputs[0]) # type: ignore[arg-type]
  2404. )
  2405. def is_collective(
  2406. node: Optional[Union[Node, Operation]],
  2407. op: Optional[torch._ops.OperatorBase] = None,
  2408. ) -> bool:
  2409. if node is None:
  2410. return False
  2411. from . import ir
  2412. return (
  2413. isinstance(node, ir._CollectiveKernel)
  2414. and not isinstance(node, ir._WaitKernel)
  2415. and (op is None or node.op_overload is op)
  2416. ) or (
  2417. # TODO: this is a temporary solution to ensure that we can identify torchrec's
  2418. # communication ops. But in order to allow better communication and computation
  2419. # overlap, torchrec's communication ops should be not used.
  2420. type(node) is ir.FallbackKernel
  2421. and (
  2422. # NOTE: the `hasattr()` check is to bypass errors such as the following:
  2423. # AttributeError: '_OpNamespace' 'torchrec' object has no attribute 'all_to_all_single'
  2424. (
  2425. hasattr(torch.ops.torchrec, "all_to_all_single")
  2426. and node.op_overload == torch.ops.torchrec.all_to_all_single.default
  2427. )
  2428. or (
  2429. hasattr(torch.ops.torchrec, "all_gather_into_tensor")
  2430. and node.op_overload
  2431. == torch.ops.torchrec.all_gather_into_tensor.default
  2432. )
  2433. or (
  2434. hasattr(torch.ops.torchrec, "reduce_scatter_tensor")
  2435. and node.op_overload == torch.ops.torchrec.reduce_scatter_tensor.default
  2436. )
  2437. )
  2438. )
  2439. def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool:
  2440. from . import ir
  2441. return type(node) is ir._WaitKernel
  2442. def contains_collective(
  2443. snode: BaseSchedulerNode,
  2444. filter_fn: Optional[Callable[[BaseSchedulerNode], bool]] = None,
  2445. ) -> bool:
  2446. from torch._inductor.scheduler import GroupedSchedulerNode
  2447. if isinstance(snode, GroupedSchedulerNode):
  2448. return any(contains_collective(x) for x in snode.snodes)
  2449. return is_collective(snode.node) and (filter_fn is None or filter_fn(snode))
  2450. def contains_wait(snode: BaseSchedulerNode) -> bool:
  2451. from torch._inductor.scheduler import GroupedSchedulerNode
  2452. if isinstance(snode, GroupedSchedulerNode):
  2453. return any(contains_wait(x) for x in snode.snodes)
  2454. else:
  2455. return is_wait(snode.node)
  2456. def is_fallback_op(
  2457. node: Optional[Operation],
  2458. op: Union[torch._ops.OpOverload, Collection[torch._ops.OpOverload]],
  2459. ) -> bool:
  2460. from . import ir
  2461. if isinstance(op, torch._ops.OpOverload):
  2462. op = [op]
  2463. return isinstance(node, ir.FallbackKernel) and node.op_overload in op
  2464. def buf_name_to_fused_snode(
  2465. buf_name: str, name_to_buf: dict[str, Any], name_to_fused_node: dict[str, Any]
  2466. ) -> Any:
  2467. return name_to_fused_node[name_to_buf[buf_name].defining_op.get_name()]
  2468. def find_recursive_deps_of_node(
  2469. snode: BaseSchedulerNode,
  2470. collected_node_set: MutableSet[BaseSchedulerNode],
  2471. name_to_buf: dict[str, SchedulerBuffer],
  2472. name_to_fused_node: dict[str, BaseSchedulerNode],
  2473. criteria_cb: Callable[[Any], bool] = lambda snode: False,
  2474. ) -> None:
  2475. if criteria_cb(snode):
  2476. return
  2477. collected_node_set.add(snode)
  2478. for dep in snode.unmet_dependencies:
  2479. defining_op_for_dep = buf_name_to_fused_snode(
  2480. dep.name, name_to_buf, name_to_fused_node
  2481. )
  2482. if defining_op_for_dep in collected_node_set:
  2483. continue
  2484. find_recursive_deps_of_node(
  2485. defining_op_for_dep,
  2486. collected_node_set,
  2487. name_to_buf,
  2488. name_to_fused_node,
  2489. criteria_cb=criteria_cb,
  2490. )
  2491. def find_recursive_users_of_node(
  2492. snode: BaseSchedulerNode,
  2493. collected_node_set: MutableSet[BaseSchedulerNode],
  2494. name_to_buf: dict[str, SchedulerBuffer],
  2495. name_to_fused_node: dict[str, BaseSchedulerNode],
  2496. criteria_cb: Callable[[Any], bool] = lambda snode: False,
  2497. ) -> None:
  2498. if criteria_cb(snode):
  2499. return
  2500. collected_node_set.add(snode)
  2501. for o in snode.get_outputs():
  2502. for user in o.users:
  2503. assert user.node is not None
  2504. if user.node.get_name() == "OUTPUT":
  2505. continue
  2506. if user.node.get_name() not in name_to_fused_node:
  2507. continue
  2508. user_op = name_to_fused_node[user.node.get_name()]
  2509. if user_op in collected_node_set:
  2510. continue
  2511. find_recursive_users_of_node(
  2512. user_op,
  2513. collected_node_set,
  2514. name_to_buf,
  2515. name_to_fused_node,
  2516. criteria_cb=criteria_cb,
  2517. )
  2518. def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int) -> int:
  2519. "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
  2520. num_rng_seed_offset_inputs = (
  2521. 2 if torch._functorch.config.functionalize_rng_ops else 0
  2522. )
  2523. # AOT won't lift any parameters if we're inlining NN Modules
  2524. # however desugaring subclasses will still add arguments
  2525. # resulted in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502
  2526. return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
  2527. def count_tangents(fx_g: torch.fx.GraphModule) -> int:
  2528. """
  2529. Infers which inputs are static for a backwards graph
  2530. """
  2531. def is_saved_tensor(x: Node) -> bool:
  2532. return (
  2533. "tangents" not in x.name
  2534. and "bwd_seed" not in x.name
  2535. and "bwd_base_offset" not in x.name
  2536. and "bwd_rng_state" not in x.name
  2537. )
  2538. arg_count = 0
  2539. static_arg_idxs = []
  2540. for n in fx_g.graph.nodes:
  2541. if n.op == "placeholder":
  2542. if is_saved_tensor(n):
  2543. static_arg_idxs.append(arg_count)
  2544. arg_count += 1
  2545. assert static_arg_idxs == list(range(len(static_arg_idxs)))
  2546. return len(static_arg_idxs)
  2547. @dataclasses.dataclass
  2548. class BoxedBool:
  2549. value: bool
  2550. def __bool__(self) -> bool:
  2551. return self.value
  2552. @staticmethod
  2553. def disable(obj: Any) -> Union[BoxedBool, bool]:
  2554. if isinstance(obj, BoxedBool):
  2555. obj.value = False
  2556. return obj
  2557. return False
  2558. @contextlib.contextmanager
  2559. def collect_defined_kernels(kernel_list: list[str]) -> Iterator[None]:
  2560. from .codegen.wrapper import PythonWrapperCodegen
  2561. orig_define_kernel = PythonWrapperCodegen.define_kernel
  2562. def define_kernel(
  2563. self: PythonWrapperCodegen,
  2564. kernel_name: str,
  2565. kernel_code: str,
  2566. metadata: Optional[str] = None,
  2567. gpu: bool = True,
  2568. cpp_definition: Optional[str] = None,
  2569. ) -> Any:
  2570. kernel_list.append(kernel_code)
  2571. return orig_define_kernel(
  2572. self, kernel_name, kernel_code, metadata, gpu, cpp_definition
  2573. )
  2574. with mock.patch.object(PythonWrapperCodegen, "define_kernel", define_kernel):
  2575. yield
  2576. def get_cloned_parameter_buffer_name(name: str) -> str:
  2577. return name + "__original__"
  2578. def is_gpu(device: Optional[str]) -> bool:
  2579. return device in GPU_TYPES
  2580. def is_rocm() -> bool:
  2581. """Check if we're running on ROCm/HIP platform."""
  2582. return torch.version.hip is not None
  2583. def device_need_guard(device: str) -> bool:
  2584. return device != "mps" and is_gpu(device) # TODO: MPS does not expose streams now
  2585. def needs_fallback_due_to_atomic_add_limitations(dtype: torch.dtype) -> bool:
  2586. if dtype == torch.bfloat16 and torch.cuda.is_available():
  2587. return torch.cuda.get_device_capability() < (9, 0)
  2588. elif dtype == torch.bfloat16 and torch.xpu.is_available():
  2589. return True
  2590. else:
  2591. return dtype in (torch.int64, torch.bool)
  2592. def use_scatter_fallback(
  2593. op_overload: torch._ops.OpOverload,
  2594. reduction_type: Optional[str],
  2595. self_dtype: torch.dtype,
  2596. src_dtype: torch.dtype,
  2597. src_device_type: str,
  2598. src_is_tensor: bool,
  2599. ) -> bool:
  2600. if (
  2601. op_overload.overloadpacket
  2602. in (torch.ops.aten.scatter_reduce_, torch.ops.aten.scatter_reduce)
  2603. and reduction_type is None
  2604. ):
  2605. return False
  2606. reduce_ty = (
  2607. "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum"
  2608. )
  2609. return (
  2610. reduction_type not in (None, reduce_ty)
  2611. or (
  2612. src_is_tensor
  2613. and is_gpu(src_device_type)
  2614. and needs_fallback_due_to_atomic_add_limitations(src_dtype)
  2615. )
  2616. or (
  2617. op_overload.overloadpacket == torch.ops.aten.scatter_reduce_
  2618. and reduction_type == "sum"
  2619. and src_is_tensor
  2620. and src_device_type == "cpu"
  2621. and config.cpp.fallback_scatter_reduce_sum
  2622. and (config.cpp.dynamic_threads or parallel_num_threads() != 1)
  2623. )
  2624. or (reduction_type == reduce_ty and self_dtype in (torch.bool, torch.int64))
  2625. or torch.are_deterministic_algorithms_enabled()
  2626. )
  2627. def dump_node_schedule(node_schedule: Sequence[BaseSchedulerNode]) -> None:
  2628. """
  2629. An API that can be used in pdb to dump a node_schedule.
  2630. Right mainly dump the read/write dependencies but can add more as needed.
  2631. """
  2632. from torch._inductor.codegen.simd import DisableReduction, EnableReduction
  2633. from torch._inductor.scheduler import SchedulerNode
  2634. print(f"Node schedule with {len(node_schedule)} nodes")
  2635. for idx, node in enumerate(node_schedule):
  2636. print(f" {idx:3}:")
  2637. # pyrefly: ignore [unnecessary-comparison]
  2638. if node is EnableReduction:
  2639. print("enable reduction")
  2640. # pyrefly: ignore [unnecessary-comparison]
  2641. elif node is DisableReduction:
  2642. print("disable reduction")
  2643. elif isinstance(node, SchedulerNode):
  2644. is_red = node.is_reduction()
  2645. print(f"{'red' if is_red else 'pw'} scheduler node")
  2646. if is_red:
  2647. assert node.node is not None
  2648. print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined]
  2649. print("ReadDep:")
  2650. for dep in node.read_writes.reads:
  2651. print(dep)
  2652. print("WriteDep:")
  2653. for dep in node.read_writes.writes:
  2654. print(dep)
  2655. else:
  2656. raise RuntimeError(f"Unrecognized node type: {type(node)}")
  2657. def tensor_is_aligned(tensor: torch.Tensor) -> bool:
  2658. # See Note: [Input Alignment handling in Inductor]
  2659. # Right now, we don't try to guard on the alignment of the storage offset.
  2660. # When this comment was written, non-symbolic storage_offsets are not guarded on
  2661. # but symbolic storage_offsets are. For consistency, we suppress guard creation
  2662. # upon performing this check: that ensures that we don't add recompiles when we
  2663. # add this logic.
  2664. from torch.fx.experimental.symbolic_shapes import statically_known_true
  2665. return statically_known_true(
  2666. (tensor.storage_offset() * get_dtype_size(tensor.dtype)) % GPU_ALIGN_BYTES == 0
  2667. )
  2668. def should_assume_input_aligned(example_input: torch.Tensor) -> bool:
  2669. # See Note: [Input Alignment handling in Inductor]
  2670. # right now, we only care about alignment for cuda tensors.
  2671. if not is_gpu(example_input.device.type):
  2672. return False
  2673. return config.assume_aligned_inputs or tensor_is_aligned(example_input)
  2674. def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[None]:
  2675. # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards()
  2676. # If it's not available, return a nullcontext.
  2677. # If we're dealing with cudagraphs, we might not have a tracing_context
  2678. tracing_context = torch._guards.TracingContext.try_get()
  2679. if not tracing_context:
  2680. return contextlib.nullcontext()
  2681. # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
  2682. if not tracing_context.fake_mode or not tracing_context.fake_mode.shape_env:
  2683. return contextlib.nullcontext()
  2684. shape_env = tracing_context.fake_mode.shape_env
  2685. return shape_env.suppress_guards()
  2686. def run_and_get_cpp_code(
  2687. fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
  2688. ) -> tuple[_T, str]:
  2689. # We use the patch context manager instead of using it as a decorator.
  2690. # In this way, we can ensure that the attribute is patched and unpatched correctly
  2691. # even if this run_and_get_cpp_code function is called multiple times.
  2692. with unittest.mock.patch.object(config, "debug", True):
  2693. torch._dynamo.reset()
  2694. import io
  2695. import logging
  2696. log_capture_string = io.StringIO()
  2697. ch = logging.StreamHandler(log_capture_string)
  2698. from torch._inductor.codecache import output_code_log
  2699. output_code_log.addHandler(ch)
  2700. prev_level = output_code_log.level
  2701. output_code_log.setLevel(logging.DEBUG)
  2702. result = fn(*args, **kwargs)
  2703. s = log_capture_string.getvalue()
  2704. output_code_log.setLevel(prev_level)
  2705. output_code_log.removeHandler(ch)
  2706. return result, s
  2707. def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
  2708. fake_mode = detect_fake_mode(inputs)
  2709. # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
  2710. # pass in real inputs for now.
  2711. # if len(inputs) > 0:
  2712. # assert fake_mode is not None, breakpoint()
  2713. if fake_mode is not None:
  2714. return fake_mode.shape_env
  2715. # When there are no tensor inputs, get shape_env from the first SymInt.
  2716. for input in inputs:
  2717. if isinstance(input, torch.SymInt):
  2718. return input.node.shape_env
  2719. # Check tensor sizes and strides for SymInt values
  2720. if isinstance(input, torch.Tensor):
  2721. for size in input.size():
  2722. if isinstance(size, torch.SymInt):
  2723. return size.node.shape_env
  2724. for stride in input.stride():
  2725. if isinstance(stride, torch.SymInt):
  2726. return stride.node.shape_env
  2727. # TODO(voz): Should we always have one anyway?
  2728. return None
  2729. def align_inputs_from_check_idxs(
  2730. model: Callable[[list[InputType]], _T],
  2731. inputs_to_check: Sequence[int],
  2732. mutated_input_idxs: OrderedSet[int],
  2733. ) -> Callable[[list[InputType]], _T]:
  2734. if len(inputs_to_check) == 0:
  2735. return model
  2736. def run(new_inputs: list[InputType]) -> Any:
  2737. old_tensors, new_tensors = copy_misaligned_inputs(
  2738. new_inputs, inputs_to_check, mutated_input_idxs
  2739. )
  2740. out = model(new_inputs)
  2741. # If a mutated tensor was cloned to be aligned, we need to reflect back the mutation to the
  2742. # original tensor.
  2743. if len(old_tensors):
  2744. torch._foreach_copy_(old_tensors, new_tensors)
  2745. return out
  2746. return run
  2747. def clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
  2748. if 0 in x.size():
  2749. # Short-circuits if the shape has no elements
  2750. needed_size = 0
  2751. else:
  2752. needed_size = (
  2753. sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())) + 1
  2754. )
  2755. buffer = torch.as_strided(x, (needed_size,), (1,)).clone()
  2756. return torch.as_strided(buffer, x.size(), x.stride())
  2757. def copy_misaligned_inputs(
  2758. new_inputs: list[InputType],
  2759. check_inputs_idxs: Sequence[int],
  2760. return_pair_idxs: Optional[OrderedSet[int]] = None,
  2761. ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
  2762. """
  2763. Clones misaligned tensors which we inferred were aligned. Returns a tuple of [old_tensors], [new_tensors] for every
  2764. cloned tensor which is in `return_pair_idxs`.
  2765. """
  2766. old_tensors: list[torch.Tensor] = []
  2767. new_tensors: list[torch.Tensor] = []
  2768. # hoist above loop because this is on the hot path
  2769. ret_pair_defined = return_pair_idxs is not None
  2770. for i in check_inputs_idxs:
  2771. _inp = new_inputs[i]
  2772. assert isinstance(_inp, torch.Tensor), (
  2773. f"Expected tensors only, but got: {type(_inp)}"
  2774. )
  2775. if _inp.data_ptr() % ALIGNMENT:
  2776. new_inputs[i] = clone_preserve_strides(_inp)
  2777. if ret_pair_defined and i in return_pair_idxs: # type: ignore[operator]
  2778. old_tensors.append(_inp)
  2779. new_tensors.append(new_inputs[i]) # type: ignore[arg-type]
  2780. return old_tensors, new_tensors
  2781. def remove_unaligned_input_idxs(
  2782. inputs: Sequence[InputType],
  2783. static_input_idxs: Sequence[int],
  2784. ) -> Sequence[int]:
  2785. """
  2786. We require all inputs to be aligned, so introduce a copy for any
  2787. that aren't.
  2788. """
  2789. aligned_static_input_idxs = []
  2790. for idx in static_input_idxs:
  2791. input = inputs[idx]
  2792. if isinstance(input, torch.Tensor) and (input.data_ptr() % ALIGNMENT) == 0:
  2793. aligned_static_input_idxs.append(idx)
  2794. if len(aligned_static_input_idxs) != len(static_input_idxs):
  2795. return aligned_static_input_idxs
  2796. return static_input_idxs
def expr_fits_within_32bit(e: sympy.Expr) -> bool:
    """
    Return whether the expression ``e`` may be treated as fitting within int32.

    Decision order:
      1. If ``config.assume_32bit_indexing`` is set, call
         ``V.graph.sizevars.check_leq(e, int32 max)`` and return True.
      2. If ``e <= int32 max`` is statically provable, return True.
      3. Under AOT compilation (``V.aot_compilation``), return False if
         ``e < 1e20`` is statically provable, i.e. ``e`` has an upper bound that
         step 2 could not prove to be within int32.
      4. Otherwise return True only if ``e`` has a size hint and that hint is
         ``<= int32 max``.
    """
    from .virtualized import V

    int_max = torch.iinfo(torch.int32).max
    size_hint = V.graph.sizevars.size_hint
    has_hint = V.graph.sizevars.shape_env.has_hint
    if config.assume_32bit_indexing:
        # The user opted into 32-bit indexing; hand the bound to sizevars
        # (check_leq presumably asserts/guards it; confirm in sizevars).
        V.graph.sizevars.check_leq(e, int_max)  # type: ignore[arg-type]
        return True
    # Allow for unhinted e as long as we can still statically prove
    # (e.g., via ValueRanges) that it is still in bounds
    if V.graph.sizevars.statically_known_true(e <= int_max):
        return True
    # AOTI doesn't guard on the hinted value fitting in int32, so checking hints
    # isn't a viable option when the hinted value fits but the allowed range is
    # larger.
    # However, to prevent possible perf regressions on pre-existing AOTI models
    # which don't set an upper bound on the valid range, we'll skip the check.
    # To recap:
    # - If using AOTI:
    #   - If allowed range has no upper bound, then check the hint to determine
    #     whether this fits in int32
    #   - If allowed range does have an upper bound, then obey the upper bound
    #     (check whether upper bound < int32_max) without checking the hint.
    if V.aot_compilation:
        # check whether value has an upper bound (1e20 is > INT64_MAX, assume
        # there is no upper bound if it can be larger than 1e20)
        if V.graph.sizevars.statically_known_true(e < 1e20):
            # if so, then assume int_max < upper bound < inf
            # so this could potentially have int64 values
            return False
    # Otherwise, the hint MUST exist and be in range
    return has_hint(e) and size_hint(e) <= int_max
  2828. def set_tracing_context_output_strides(
  2829. example_inputs: Sequence[Any], compiled_graph: CompiledFxGraph
  2830. ) -> None:
  2831. # Return the output strides to the caller via TracingContext
  2832. context = torch._guards.TracingContext.try_get()
  2833. if context is not None and context.output_strides is not None:
  2834. assert len(context.output_strides) == 0
  2835. shape_env = shape_env_from_inputs(example_inputs)
  2836. assert compiled_graph.output_strides is not None
  2837. for exprs in compiled_graph.output_strides:
  2838. if exprs is None:
  2839. context.output_strides.append(None)
  2840. else:
  2841. fakify_first_call = False
  2842. if ctx := torch._guards.TracingContext.try_get():
  2843. fakify_first_call = ctx.fakify_first_call
  2844. def map_expr(e: Any) -> Union[float, int, SymInt, SymFloat, SymBool]:
  2845. if shape_env is None:
  2846. return int(e)
  2847. if fakify_first_call:
  2848. return shape_env.deserialize_symexpr(e)
  2849. return shape_env.evaluate_symexpr(e)
  2850. context.output_strides.append(
  2851. tuple(map_expr(e) for e in exprs) # type: ignore[misc]
  2852. )
  2853. def should_use_remote_fx_graph_cache() -> bool:
  2854. if config.fx_graph_remote_cache is not None:
  2855. return config.fx_graph_remote_cache
  2856. if not config.is_fbcode():
  2857. return False
  2858. if torch._utils_internal.is_fb_unit_test():
  2859. return False
  2860. try:
  2861. from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
  2862. except ModuleNotFoundError:
  2863. return False
  2864. return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(
  2865. "pytorch/remote_cache:fx_graph_memcache_version"
  2866. )
  2867. def normalize_name(name: str) -> str:
  2868. return re.sub(r"[^a-zA-Z0-9_]", "_", name)
  2869. # correct cases where Triton types names don't match PyTorch
  2870. _triton_type_mapping = {
  2871. "tl.bool": "tl.int1",
  2872. "tl.float8_e4m3fn": "tl.float8e4nv",
  2873. "tl.float8_e5m2": "tl.float8e5",
  2874. "tl.float8_e4m3fnuz": "tl.float8e4b8",
  2875. "tl.float8_e5m2fnuz": "tl.float8e5b16",
  2876. # TODO: remove when support is added in triton
  2877. # https://github.com/triton-lang/triton/issues/6054
  2878. "tl.float8_e8m0fnu": "tl.uint8",
  2879. "tl.float4_e2m1fn_x2": "tl.uint8",
  2880. }
  2881. _torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()}
  2882. _triton_type_re = re.compile(r"^.*[.]")
  2883. def triton_type(dtype: torch.dtype) -> str:
  2884. """Convert torch.dtype to triton type"""
  2885. triton_type_name = _triton_type_re.sub("tl.", str(dtype))
  2886. return _triton_type_mapping.get(triton_type_name, triton_type_name)
  2887. def triton_type_to_torch(dtype: str) -> torch.dtype:
  2888. adjusted_type = _torch_triton_mapping.get(dtype, dtype)
  2889. type_name = adjusted_type.replace("tl.", "")
  2890. out_dtype = getattr(torch, type_name)
  2891. assert isinstance(out_dtype, torch.dtype)
  2892. return out_dtype
  2893. def is_same_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
  2894. return (
  2895. not data.is_mkldnn
  2896. and data.size() == value.size()
  2897. and data.stride() == value.stride()
  2898. and data.dtype == value.dtype
  2899. and data.device == value.device
  2900. and data.untyped_storage().data_ptr() == value.untyped_storage().data_ptr()
  2901. and data.storage_offset() == value.storage_offset()
  2902. )
  2903. def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
  2904. return (
  2905. data.is_mkldnn
  2906. and data.size() == value.size()
  2907. and data.dtype == value.dtype
  2908. and data.device == value.device
  2909. and torch.ops.mkldnn.data_ptr(data) == torch.ops.mkldnn.data_ptr(value)
  2910. )
  2911. @functools.cache
  2912. def boolean_ops() -> tuple[str, ...]:
  2913. return (
  2914. "isinf",
  2915. "isnan",
  2916. "logical_not",
  2917. "logical_and",
  2918. "signbit",
  2919. "and_",
  2920. "le",
  2921. "lt",
  2922. "ge",
  2923. "gt",
  2924. "eq",
  2925. "ne",
  2926. "or_", # TODO should remove this op
  2927. "xor",
  2928. )
  2929. @dataclasses.dataclass
  2930. class OpDtypeRule:
  2931. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
  2932. override_return_dtype: Optional[torch.dtype]
  2933. op_dtype_propagation_rules: dict[str, OpDtypeRule] = {}
  2934. def register_op_dtype_propagation_rules(
  2935. name: str,
  2936. type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
  2937. override_return_dtype: Optional[torch.dtype],
  2938. ) -> None:
  2939. op_dtype_propagation_rules[name] = OpDtypeRule(
  2940. type_promotion_kind, override_return_dtype
  2941. )
  2942. op_requires_libdevice_fp64: OrderedSet[str] = OrderedSet()
  2943. def register_op_requires_libdevice_fp64(name: str) -> None:
  2944. op_requires_libdevice_fp64.add(name)
  2945. def get_current_backend(device_type: Optional[str] = None) -> str:
  2946. from torch._inductor.virtualized import V
  2947. if not device_type:
  2948. device_type = V.graph.get_current_device_or_throw().type
  2949. if device_type == "cpu":
  2950. return config.cpu_backend
  2951. elif device_type == "mps":
  2952. return "mps"
  2953. elif device_type == "xpu":
  2954. return config.xpu_backend
  2955. else:
  2956. return config.cuda_backend
  2957. def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
  2958. """Maybe upcast [b]float16 to float32"""
  2959. if (
  2960. dtype in (torch.float16, torch.bfloat16)
  2961. and config.triton.codegen_upcast_to_fp32
  2962. and get_current_backend() == "triton"
  2963. ):
  2964. return torch.float32
  2965. return dtype
  2966. KeyType = TypeVar("KeyType")
  2967. ValType = TypeVar("ValType")
  2968. class ScopedDict(MutableMapping[KeyType, ValType]):
  2969. """
  2970. A dictionary-like object that allows for scoped updates. It maintains
  2971. an original dictionary and a set of new items that can override
  2972. the original items within the scope. The original dictionary is
  2973. unmodified.
  2974. """
  2975. def __init__(self, original_dict: Mapping[KeyType, ValType]):
  2976. self.original_dict = original_dict
  2977. self.new_items: dict[KeyType, ValType] = {}
  2978. def __getitem__(self, key: KeyType) -> ValType:
  2979. if key in self.new_items:
  2980. return self.new_items[key]
  2981. return self.original_dict[key]
  2982. def __setitem__(self, key: KeyType, value: ValType) -> None:
  2983. self.new_items[key] = value
  2984. def __contains__(self, key: object) -> bool:
  2985. return key in self.new_items or key in self.original_dict
  2986. def get(self, key: KeyType, default: Optional[ValType] = None) -> Optional[ValType]: # type: ignore[override]
  2987. if key in self.new_items:
  2988. return self.new_items[key]
  2989. return self.original_dict.get(key, default)
  2990. def __len__(self) -> int:
  2991. n = len(self.original_dict)
  2992. for k in self.new_items:
  2993. if k not in self.original_dict:
  2994. n += 1
  2995. return n
  2996. def __iter__(self) -> Iterator[KeyType]:
  2997. yield from self.original_dict
  2998. for k in self.new_items:
  2999. if k not in self.original_dict:
  3000. yield k
  3001. def __bool__(self) -> bool:
  3002. return bool(self.original_dict or self.new_items)
  3003. def __delitem__(self, key: KeyType) -> None:
  3004. raise NotImplementedError
  3005. @dataclass_transform(frozen_default=True)
  3006. def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any:
  3007. def wrap(cls: _T) -> _T:
  3008. return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload]
  3009. if cls is None:
  3010. return wrap
  3011. return wrap(cls)
  3012. def get_donated_idxs() -> Optional[list[int]]:
  3013. tracing_context = torch._guards.TracingContext.try_get()
  3014. if tracing_context is not None and tracing_context.fw_metadata:
  3015. return tracing_context.fw_metadata.bw_donated_idxs
  3016. return None
  3017. class TritonAttrsDescriptorVersion(enum.Enum):
  3018. V0_NO_TRITON = 0
  3019. V1_COMPILER = 1 # triton.compiler.compiler.AttrsDescriptor
  3020. V2_BACKENDS = 2 # triton.backends.compiler.AttrsDescriptor
  3021. V3_BACKENDS_TUPLE = (
  3022. 3 # triton.backends.compiler.AttrsDescriptor, but with tuple support
  3023. )
  3024. V4_DICT = 4 # a raw dict
  3025. @functools.cache
  3026. def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion:
  3027. if importlib.util.find_spec("triton") is None:
  3028. return TritonAttrsDescriptorVersion.V0_NO_TRITON
  3029. import triton.backends.compiler
  3030. import triton.compiler.compiler
  3031. if hasattr(triton.backends.compiler, "AttrsDescriptor"):
  3032. # Triton 3.2.0
  3033. # AttrsDescriptor was moved from triton.compiler.compiler to triton.backends.compiler.
  3034. # AttrsDescriptor and its serialization format were also changed.
  3035. # TODO: implement V3_BACKENDS_TUPLE
  3036. # On Dec 9, 2024, tuple support (triton #5220) was implemented and breaks handling.
  3037. # We don't have a way to detect this (and haven't implemented this version)
  3038. return TritonAttrsDescriptorVersion.V2_BACKENDS
  3039. elif hasattr(triton.compiler.compiler, "AttrsDescriptor"):
  3040. # Triton 3.0.0
  3041. return TritonAttrsDescriptorVersion.V1_COMPILER
  3042. else:
  3043. # After Jan 1, 2025
  3044. # AttrsDescriptor was removed and replaced with a raw dict.
  3045. return TritonAttrsDescriptorVersion.V4_DICT
  3046. def triton_version_uses_attrs_dict() -> bool:
  3047. return get_triton_attrs_descriptor_version() == TritonAttrsDescriptorVersion.V4_DICT
  3048. def get_op_names(op: torch._ops.OperatorBase) -> tuple[str, str]:
  3049. op_overload_packet_name: str = op.name()
  3050. op_overload_name = (
  3051. f"{op_overload_packet_name}.{op._overloadname}"
  3052. if isinstance(op, torch._ops.OpOverload)
  3053. else op_overload_packet_name
  3054. )
  3055. return op_overload_packet_name, op_overload_name
  3056. def _fx_node_is_input_dependent_cudagraph_unsafe(fx_node: torch.fx.Node) -> bool:
  3057. """
  3058. Check if an FX node is cudagraph-unsafe based on its input arguments.
  3059. Some ops are only cudagraph-unsafe depending on their inputs (e.g., index_put
  3060. with boolean indices triggers .nonzero() during capture, but integer indices
  3061. are safe).
  3062. """
  3063. from torch.fx.operator_schemas import normalize_function
  3064. target = fx_node.target
  3065. if not isinstance(target, torch._ops.OpOverload):
  3066. return False
  3067. # index_put with boolean indices triggers .nonzero() during capture
  3068. if target in (
  3069. torch.ops.aten.index_put.default,
  3070. torch.ops.aten.index_put_.default,
  3071. torch.ops.aten._unsafe_index_put.default,
  3072. ):
  3073. normalized = normalize_function(
  3074. target, fx_node.args, fx_node.kwargs, normalize_to_only_use_kwargs=True
  3075. )
  3076. if normalized is not None:
  3077. _, kwargs = normalized
  3078. indices = kwargs["indices"]
  3079. for idx in indices:
  3080. if idx is not None and idx.meta["val"].dtype in (
  3081. torch.bool,
  3082. torch.uint8,
  3083. ):
  3084. return True
  3085. return False
  3086. def is_cudagraph_unsafe_fx_node(fx_node: torch.fx.Node) -> bool:
  3087. """
  3088. Check if an FX node is cudagraph-unsafe.
  3089. This includes:
  3090. - Ops in FORBIDDEN_CUDAGRAPH_OPS (CPU sync, dynamic alloc, etc.)
  3091. - Ops with the cudagraph_unsafe tag
  3092. - Input-dependent unsafe ops (e.g., index_put with boolean indices)
  3093. - Ops with sparse tensor outputs
  3094. """
  3095. target = fx_node.target
  3096. # Check against the forbidden ops set
  3097. if str(target) in FORBIDDEN_CUDAGRAPH_OPS:
  3098. return True
  3099. # Check for cudagraph_unsafe tag
  3100. if (
  3101. isinstance(target, torch._ops.OpOverload)
  3102. and torch._C.Tag.cudagraph_unsafe in target.tags # type: ignore[attr-defined]
  3103. ):
  3104. return True
  3105. # Check for input-dependent unsafety
  3106. if _fx_node_is_input_dependent_cudagraph_unsafe(fx_node):
  3107. return True
  3108. # Check for sparse tensor outputs
  3109. if (val := fx_node.meta.get("val")) is not None:
  3110. vals = [val] if not isinstance(val, (list, tuple)) else val
  3111. for v in vals:
  3112. if isinstance(v, torch.Tensor) and v.is_sparse:
  3113. return True
  3114. return False
  3115. def is_cudagraph_unsafe_op(node: Operation) -> bool:
  3116. """
  3117. Returns True if the node is an op that is not cudagraphable.
  3118. This includes:
  3119. - Ops in FORBIDDEN_CUDAGRAPH_OPS (CPU sync, dynamic alloc, etc.)
  3120. - Ops with the cudagraph_unsafe tag
  3121. - index_put_ with boolean indices (triggers .nonzero() during capture)
  3122. - Control flow nodes (Conditional, WhileLoop)
  3123. - Ops with sparse tensor outputs
  3124. """
  3125. from . import ir
  3126. # Control flow nodes are cudagraph-unsafe
  3127. if isinstance(node, (ir.Conditional, ir.WhileLoop)):
  3128. return True
  3129. if not isinstance(node, (ir.FallbackKernel, ir.ExternKernel)):
  3130. return False
  3131. fx_node = getattr(node, "fx_node", None)
  3132. if fx_node is not None and is_cudagraph_unsafe_fx_node(fx_node):
  3133. return True
  3134. return False
  3135. def get_ld_library_path() -> str:
  3136. path = os.environ.get("LD_LIBRARY_PATH", "")
  3137. if config.is_fbcode():
  3138. from libfb.py.parutil import get_runtime_path
  3139. runtime_path = get_runtime_path()
  3140. if runtime_path:
  3141. lib_path = os.path.join(runtime_path, "runtime", "lib")
  3142. path = os.pathsep.join([lib_path, path]) if path else lib_path
  3143. return path
  3144. def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
  3145. from torch._inductor.codegen.wrapper import SubgraphPythonWrapperCodegen
  3146. return (
  3147. isinstance(wrapper, SubgraphPythonWrapperCodegen)
  3148. and wrapper.partition_signatures is not None
  3149. )
  3150. def is_using_cudagraph_partition() -> bool:
  3151. return (
  3152. torch._inductor.config.triton.cudagraphs
  3153. or _unstable_customized_partition_wrapper.wrapper is not None
  3154. ) and torch._inductor.config.graph_partition
  3155. def dtype_from_size(size: int) -> torch.dtype:
  3156. from .virtualized import V
  3157. if V.graph.sizevars.statically_known_lt(
  3158. size, 2**31
  3159. ) and V.graph.sizevars.statically_known_geq(size, -(2**31)):
  3160. return torch.int32
  3161. else:
  3162. return torch.int64
  3163. SUPPORTED_MKLDNN_DEVICES = ("cpu", "xpu")
  3164. def is_mkldnn_bf16_supported(device_type: str) -> bool:
  3165. """
  3166. Returns True if the device supports MKL-DNN BF16.
  3167. """
  3168. if device_type == "cpu":
  3169. return torch.ops.mkldnn._is_mkldnn_bf16_supported()
  3170. elif "xpu" in device_type:
  3171. # match "xpu", "xpu:0", "xpu:1", etc.
  3172. return True
  3173. return False
  3174. def is_mkldnn_fp16_supported(device_type: str) -> bool:
  3175. """
  3176. Returns True if the device supports MKL-DNN FP16.
  3177. """
  3178. if device_type == "cpu":
  3179. return torch.ops.mkldnn._is_mkldnn_fp16_supported()
  3180. elif "xpu" in device_type:
  3181. # match "xpu", "xpu:0", "xpu:1", etc.
  3182. return True
  3183. return False
  3184. def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
  3185. widths = [len(str(e)) for e in headers]
  3186. for row in elements:
  3187. assert len(row) == len(headers)
  3188. for i, e in enumerate(row):
  3189. widths[i] = max(widths[i], len(str(e)))
  3190. lines = []
  3191. lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths)))
  3192. # widths whitespace horizontal separators
  3193. total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
  3194. lines.append("-" * total_width)
  3195. for row in elements:
  3196. lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths)))
  3197. return "\n".join(lines)
  3198. def zip_dicts(
  3199. dict1: Mapping[KeyType, ValType],
  3200. dict2: Mapping[KeyType, ValType],
  3201. d1_default: ValType | None = None,
  3202. d2_default: ValType | None = None,
  3203. ) -> Generator[tuple[KeyType, ValType | None, ValType | None], None, None]:
  3204. """
  3205. Zip two dictionaries together, replacing missing keys with default values.
  3206. Args:
  3207. dict1 (dict): The first dictionary.
  3208. dict2 (dict): The second dictionary.
  3209. d1_default (Any): the default value for the first dictionary
  3210. d2_default (Any): the default value for the second dictionary
  3211. Yields:
  3212. tuple: A tuple containing the key, the value from dict1 (or d1_default if missing),
  3213. and the value from dict2 (or d2_default if missing).
  3214. """
  3215. # Find the union of all keys
  3216. all_keys = OrderedSet(dict1.keys()) | OrderedSet(dict2.keys())
  3217. # Iterate over all keys
  3218. for key in all_keys:
  3219. # Get the values from both dictionaries, or default if missing
  3220. value1 = dict1.get(key)
  3221. value2 = dict2.get(key)
  3222. yield (
  3223. key,
  3224. value1 if value1 is not None else d1_default,
  3225. value2 if value2 is not None else d2_default,
  3226. )
  3227. def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, Any]:
  3228. """
  3229. Ensures the configuration is internally consistent for standalone AOTInductor.
  3230. If `aot_inductor_mode.compile_standalone` is set to True in the provided
  3231. `config_patches` (or falls back to the global config), this function ensures
  3232. that the following configs are also enabled:
  3233. - `aot_inductor.package_cpp_only`
  3234. Args:
  3235. config_patches (dict[str, Any]): A dictionary of user-provided config
  3236. overrides for AOTInductor compilation.
  3237. Returns:
  3238. dict[str, Any]: The possibly-updated `config_patches` dictionary.
  3239. """
  3240. def patch_config(
  3241. config_patches: dict[str, Any], config_name: str, config_value: Any
  3242. ) -> None:
  3243. value = config_patches.get(config_name, getattr(config, config_name))
  3244. if value is None:
  3245. config_patches[config_name] = config_value
  3246. elif not value and value != config_value:
  3247. raise RuntimeError(
  3248. f"Invalid config: {config_name}={config_value} when aot_inductor_mode.compile_standalone is True."
  3249. )
  3250. def force_patch_config(
  3251. config_patches: dict[str, Any], config_name: str, config_value: Any
  3252. ) -> None:
  3253. value = config_patches.get(config_name, getattr(config, config_name))
  3254. if value != config_value:
  3255. log.warning(
  3256. "Overriding: %s=%s when aot_inductor_mode.compile_standalone is True.",
  3257. config_name,
  3258. config_value,
  3259. )
  3260. config_patches[config_name] = config_value
  3261. compile_standalone = config_patches.get(
  3262. "aot_inductor_mode.compile_standalone",
  3263. config.aot_inductor_mode.compile_standalone,
  3264. )
  3265. # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing
  3266. config_patches = config_patches.copy()
  3267. if compile_standalone:
  3268. # Standlaone AOTInductor means only generate cpp project for building a standalone binary
  3269. patch_config(config_patches, "aot_inductor.package_cpp_only", True)
  3270. # Standlaone AOTInductor needs to embed the kernel code in the binary
  3271. patch_config(config_patches, "aot_inductor.embed_kernel_binary", True)
  3272. # Default to use multi-arch kernel codegen for non-rocm GPU
  3273. patch_config(
  3274. config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip
  3275. )
  3276. patch_config(
  3277. config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model"
  3278. )
  3279. # TODO: change these two configs to default to None and use patch_config
  3280. force_patch_config(
  3281. config_patches,
  3282. "aot_inductor.link_libtorch",
  3283. config.test_configs.use_libtorch,
  3284. )
  3285. force_patch_config(config_patches, "aot_inductor.dynamic_linkage", False)
  3286. cross_target_platform = config_patches.get(
  3287. "aot_inductor.cross_target_platform",
  3288. config.aot_inductor.cross_target_platform,
  3289. )
  3290. package_constants_in_so = config_patches.get(
  3291. "aot_inductor.package_constants_in_so",
  3292. config.aot_inductor.package_constants_in_so,
  3293. )
  3294. if cross_target_platform == "windows" and package_constants_in_so:
  3295. raise RuntimeError(
  3296. "config.aot_inductor.package_constants_in_so is not supported for windows cross-compilation. "
  3297. "Please use config.aot_inductor.package_constants_on_disk_format = binary_blob."
  3298. )
  3299. return config_patches
  3300. def determine_aoti_mmap_flags(consts_size: int) -> tuple[bool, bool]:
  3301. """
  3302. Decide whether we should mmap weights, and whether to store the weights with .so.
  3303. If force_mmap_weights or package_constants_on_disk_format == "binary_blob" configs are set, respect the config.
  3304. Returns tuple (use_external_weights, use_mmap_weights).
  3305. """
  3306. if (
  3307. config.aot_inductor.force_mmap_weights
  3308. and config.aot_inductor.package_constants_on_disk_format == "binary_blob"
  3309. ):
  3310. raise RuntimeError(
  3311. "config.aot_inductor.package_constants_on_disk_format = binary_blob and "
  3312. "config.aot_inductor.force_mmap_weights cannot both be True."
  3313. )
  3314. if config.aot_inductor.force_mmap_weights:
  3315. if config.aot_inductor.cross_target_platform == "windows":
  3316. raise RuntimeError(
  3317. "when cross_target_platform is windows, use_mmap_weights should not be true."
  3318. )
  3319. use_mmap_weights = True
  3320. use_external_weights = False
  3321. return use_external_weights, use_mmap_weights
  3322. if config.aot_inductor.package_constants_on_disk_format == "binary_blob":
  3323. use_external_weights = True
  3324. use_mmap_weights = False
  3325. return use_external_weights, use_mmap_weights
  3326. if consts_size <= 2_000_000_000:
  3327. return False, False
  3328. use_external_weights = False
  3329. use_mmap_weights = not config.is_fbcode()
  3330. return use_external_weights, use_mmap_weights
  3331. def is_valid_aoti_model_name() -> bool:
  3332. """
  3333. Validates if a model name is suitable for use in code generation.
  3334. """
  3335. from torch._inductor import config
  3336. model_name = config.aot_inductor.model_name_for_generated_files
  3337. if model_name is None:
  3338. return True
  3339. if not isinstance(model_name, str):
  3340. raise ValueError("Invalid AOTI model name: Model name must be a string")
  3341. if model_name == "":
  3342. return True
  3343. # Can only contain alphanumeric characters and underscores
  3344. if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", model_name):
  3345. raise ValueError(
  3346. "Invalid AOTI model name: Model name can only contain letters, numbers, and underscores"
  3347. )
  3348. return True
  3349. def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
  3350. if unbacked_only:
  3351. return free_unbacked_symbols(x)
  3352. else:
  3353. return free_symbols(x)
  3354. def python_subprocess_env() -> dict[str, str]:
  3355. """
  3356. Get a base environment for running Python subprocesses.
  3357. """
  3358. env = {
  3359. # Inherit the environment of the current process.
  3360. **os.environ,
  3361. # Set the PYTHONPATH so the subprocess can find torch.
  3362. "PYTHONPATH": os.environ.get(
  3363. "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path)
  3364. ),
  3365. }
  3366. # Set PYTHONHOME for internal builds, to account for builds that bundle the
  3367. # runtime. Otherwise they will use the libraries and headers from the
  3368. # platform runtime instead.
  3369. #
  3370. # This can't be done for external builds. The process can be run from a
  3371. # venv and that won't include Python headers. The process needs to be able
  3372. # to search for and find the platform runtime.
  3373. if config.is_fbcode():
  3374. env["PYTHONHOME"] = sysconfig.get_path("data")
  3375. return env
  3376. @dataclasses.dataclass(frozen=True)
  3377. class CUDAGraphWrapperMetadata:
  3378. """
  3379. Metadata for Customized CUDAGraphWrapper.
  3380. Currently assumes there is 1 dynamo graph and will extend to
  3381. multiple graphs in the future.
  3382. """
  3383. # The number of partitions that are cudagraphable.
  3384. num_partitions: int
  3385. # Index of the current partition.
  3386. partition_index: int
  3387. PartitionFnType = Callable[..., Any]
  3388. CUDAGraphWrapperType = Callable[
  3389. [PartitionFnType, CUDAGraphWrapperMetadata], PartitionFnType
  3390. ]
  3391. # only incremented by user call of mark_step_begin
  3392. class CUDAGraphWrapper:
  3393. wrapper: Optional[CUDAGraphWrapperType] = None
  3394. # A customized partition wrappers from users. Interface should be:
  3395. #
  3396. # def wrapper(fn: PartitionFnType, metadata: CUDAGraphWrapperMetadata) -> PartitionFnType
  3397. #
  3398. # Inductor generates N wrapper functions for N partition functions, and mechanically wrap
  3399. # each partition fn with the generated wrapper function. Users need to handle all details
  3400. # such as static inputs, dynamic shapes, etc.
  3401. # Users could customize the wrapper based on the metadata. One example is to have special
  3402. # handle for the first and last wrapper function.
  3403. #
  3404. # Warning: This API is unstable and may change in the future.
  3405. _unstable_customized_partition_wrapper = CUDAGraphWrapper()
  3406. def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
  3407. _unstable_customized_partition_wrapper.wrapper = wrapper
  3408. def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
  3409. args = snode.node.inputs # type: ignore[union-attr]
  3410. args = snode.node.fill_non_provided_args( # type: ignore[union-attr]
  3411. [*args, *snode.node.constant_args], # type: ignore[union-attr]
  3412. snode.node.kwargs, # type: ignore[union-attr]
  3413. )
  3414. kwargs = snode.node.kwargs # type: ignore[union-attr]
  3415. flat_args, flat_args_pytree_spec = pytree.tree_flatten((args, kwargs))
  3416. def _is_tensor_ir(x) -> bool: # type: ignore[no-untyped-def]
  3417. return isinstance(x, torch._inductor.ir.IRNode) and not isinstance(
  3418. x, torch._inductor.ir.GeneratorState
  3419. )
  3420. flat_args = [
  3421. torch._inductor.ir.ir_node_to_tensor(a, guard_shape=False)
  3422. if _is_tensor_ir(a)
  3423. else a
  3424. for a in flat_args
  3425. ]
  3426. def _tensor(size, dtype, device) -> torch.Tensor: # type: ignore[no-untyped-def]
  3427. return torch.empty(size, dtype=dtype, device=device)
  3428. def to_real_tensor(e: Any) -> Any:
  3429. if not isinstance(e, torch.Tensor):
  3430. return e
  3431. out = _tensor(e.size(), e.dtype, e.device)
  3432. return out
  3433. flat_args = [to_real_tensor(a) for a in flat_args]
  3434. args, kwargs = pytree.tree_unflatten(flat_args, flat_args_pytree_spec)
  3435. return args, kwargs
  3436. def is_nonfreeable_buffers(dep: Dep) -> bool:
  3437. from .virtualized import V
  3438. dep_name = dep.name
  3439. # Subgraphs have a prefix for the name, cleanup the prefix
  3440. # before checking for known strings.
  3441. if V.graph.name:
  3442. dep_name = dep_name.removeprefix(V.graph.name + "_")
  3443. return dep_name.startswith(
  3444. ("primals_", "arg", "fwd_rng_state", "bwd_rng_state", "tangents")
  3445. )
  3446. # Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
  3447. def load_template(name: str, template_dir: Path) -> str:
  3448. """Load a template file and return its content."""
  3449. with open(template_dir / f"{name}.py.jinja") as f:
  3450. return f.read()
  3451. def should_fallback_by_default(node: torch.fx.Node) -> bool:
  3452. """Decide whether fallback for a node. This is only used in inductor lite mode."""
  3453. target = node.target
  3454. assert isinstance(
  3455. target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
  3456. ), f"Expected OpOverload or HigherOrderOperator, but found {type(target)}"
  3457. if not config.fallback_by_default:
  3458. return False
  3459. # some ops need special handle due to dynamic shapes. we can avoid
  3460. # fallback if they do not impact numerics.
  3461. skip_fallback_due_to_dynamic_shape = OrderedSet(
  3462. [
  3463. torch.ops.aten._assert_scalar.default,
  3464. torch.ops.aten.lift_fresh_copy.default,
  3465. ]
  3466. )
  3467. if target in skip_fallback_due_to_dynamic_shape:
  3468. return False
  3469. # Most hops have registered lowering. We should follow the lowering and not fallback.
  3470. # However, in rare cases, hops may not register lowering, such as
  3471. # torch.ops.higher_order.triton_kernel_wrapper_functional. We should fallback for
  3472. # these hops.
  3473. fallback_hops = OrderedSet(
  3474. [torch.ops.higher_order.triton_kernel_wrapper_functional]
  3475. )
  3476. if isinstance(target, torch._ops.HigherOrderOperator):
  3477. return target in fallback_hops
  3478. return not _needs_inductor_compile(node)
  3479. # Collective operation names for specialized benchmarking
  3480. COLLECTIVE_OPS = OrderedSet(
  3481. [
  3482. "torch.ops._c10d_functional.all_reduce.default",
  3483. "torch.ops._c10d_functional.all_reduce_.default",
  3484. "torch.ops._c10d_functional.all_gather_into_tensor.default",
  3485. "torch.ops._c10d_functional.reduce_scatter_tensor.default",
  3486. "torch.ops._c10d_functional.all_to_all_single.default",
  3487. "torch.ops._c10d_functional_autograd.all_reduce.default",
  3488. "torch.ops._c10d_functional_autograd.all_gather_into_tensor.default",
  3489. "torch.ops._c10d_functional_autograd.reduce_scatter_tensor.default",
  3490. "torch.ops._c10d_functional_autograd.all_to_all_single.default",
  3491. ]
  3492. )
  3493. def is_collective_op(op_name: str) -> bool:
  3494. """Check if an operation is a collective operation."""
  3495. return op_name in COLLECTIVE_OPS
  3496. @lru_cache
  3497. def tlx_only_cuda_options() -> list[str]:
  3498. if config.is_fbcode():
  3499. try:
  3500. from torch._inductor.fb.tlx_templates.registry import tlx_only_cuda_options
  3501. return tlx_only_cuda_options
  3502. except ImportError:
  3503. return []
  3504. else:
  3505. return []
  3506. def _round_up(x: int, y: int) -> int:
  3507. """Round x up to the nearest multiple of y."""
  3508. return ((x + y - 1) // y) * y
  3509. def _infer_scale_swizzle_impl(
  3510. mat_size: tuple[Any, Any],
  3511. scale_size: tuple[Any, ...],
  3512. scale_numel: Any,
  3513. mat_dtype: torch.dtype,
  3514. scale_dtype: torch.dtype,
  3515. eq_fn: Callable[[Any, Any], bool],
  3516. ) -> tuple[Optional[Any], Optional[Any]]:
  3517. """
  3518. Core implementation for scale/swizzle inference.
  3519. """
  3520. from torch.nn.functional import ScalingType, SwizzleType
  3521. # Tensor-wise: single scale for entire tensor
  3522. if eq_fn(scale_numel, 1):
  3523. return ScalingType.TensorWise, SwizzleType.NO_SWIZZLE
  3524. # Row-wise: one scale per row or column
  3525. if len(scale_size) >= 2:
  3526. if (eq_fn(scale_size[0], mat_size[0]) and eq_fn(scale_size[1], 1)) or (
  3527. eq_fn(scale_size[0], 1) and eq_fn(scale_size[1], mat_size[1])
  3528. ):
  3529. return ScalingType.RowWise, SwizzleType.NO_SWIZZLE
  3530. # Block-wise 1x128 / 128x1 (DeepGemm style)
  3531. if (
  3532. eq_fn(scale_size[0], mat_size[0])
  3533. and eq_fn(scale_size[1], ceildiv(mat_size[1], 128))
  3534. ) or (
  3535. eq_fn(scale_size[1], mat_size[1])
  3536. and eq_fn(scale_size[0], ceildiv(mat_size[0], 128))
  3537. ):
  3538. return ScalingType.BlockWise1x128, SwizzleType.NO_SWIZZLE
  3539. # Block-wise 128x128
  3540. if eq_fn(scale_size[0], ceildiv(mat_size[0], 128)) and eq_fn(
  3541. scale_size[1], ceildiv(mat_size[1], 128)
  3542. ):
  3543. return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE
  3544. # Adjust for packed FP4 data (2 values per element)
  3545. K_multiplier = 2 if mat_dtype == torch.float4_e2m1fn_x2 else 1
  3546. # NVFP4: BlockWise1x16 with float8_e4m3fn scales
  3547. if mat_dtype == torch.float4_e2m1fn_x2 and scale_dtype == torch.float8_e4m3fn:
  3548. expected_numel_a = _round_up(mat_size[0], 128) * _round_up(
  3549. ceildiv(K_multiplier * mat_size[1], 16), 4
  3550. )
  3551. expected_numel_b = _round_up(mat_size[1], 128) * _round_up(
  3552. ceildiv(K_multiplier * mat_size[0], 16), 4
  3553. )
  3554. if eq_fn(scale_numel, expected_numel_a) or eq_fn(scale_numel, expected_numel_b):
  3555. return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4
  3556. # MXFP8: BlockWise1x32 with float8_e8m0fnu scales
  3557. if scale_dtype == torch.float8_e8m0fnu:
  3558. if not torch.version.hip:
  3559. # NVIDIA: uses swizzled 32x4x4 layout
  3560. expected_numel_a = _round_up(mat_size[0], 128) * _round_up(
  3561. ceildiv(K_multiplier * mat_size[1], 32), 4
  3562. )
  3563. expected_numel_b = _round_up(mat_size[1], 128) * _round_up(
  3564. ceildiv(K_multiplier * mat_size[0], 32), 4
  3565. )
  3566. if eq_fn(scale_numel, expected_numel_a) or eq_fn(
  3567. scale_numel, expected_numel_b
  3568. ):
  3569. return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
  3570. else:
  3571. # AMD: no swizzle
  3572. expected_numel_a = ceildiv(mat_size[0], 32) * K_multiplier * mat_size[1]
  3573. expected_numel_b = ceildiv(K_multiplier * mat_size[1], 32) * mat_size[0]
  3574. if eq_fn(scale_numel, expected_numel_a) or eq_fn(
  3575. scale_numel, expected_numel_b
  3576. ):
  3577. return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE
  3578. return None, None
  3579. def infer_scale_swizzle(
  3580. mat: torch.Tensor, scale: torch.Tensor
  3581. ) -> tuple[Optional[Any], Optional[Any]]:
  3582. """
  3583. Infer the scaling type and swizzle mode from matrix and scale tensor shapes/dtypes.
  3584. This function determines how scale factors are laid out relative to the matrix:
  3585. - TensorWise: Single scale for entire tensor
  3586. - RowWise: One scale per row
  3587. - BlockWise1x128/128x128: Block-scaled with float32 scales
  3588. - BlockWise1x32: MXFP8 with float8_e8m0fnu scales (swizzled on NVIDIA)
  3589. - BlockWise1x16: NVFP4 with float8_e4m3fn scales (swizzled)
  3590. Args:
  3591. mat: The matrix tensor (FP8 or FP4)
  3592. scale: The scale factor tensor
  3593. Returns:
  3594. Tuple of (ScalingType, SwizzleType) or (None, None) if unrecognized
  3595. """
  3596. return _infer_scale_swizzle_impl(
  3597. mat_size=(mat.shape[0], mat.shape[1]),
  3598. scale_size=tuple(scale.shape),
  3599. scale_numel=scale.numel(),
  3600. mat_dtype=mat.dtype,
  3601. scale_dtype=scale.dtype,
  3602. eq_fn=lambda a, b: a == b,
  3603. )
  3604. def infer_scale_swizzle_ir(
  3605. mat: Buffer,
  3606. scale: Buffer,
  3607. transpose: bool = False,
  3608. ) -> tuple[Optional[Any], Optional[Any]]:
  3609. """
  3610. Infer the scaling type and swizzle mode for IR nodes (used during graph lowering).
  3611. This is the IR-compatible version of infer_scale_swizzle, using symbolic
  3612. size comparisons via V.graph.sizevars.statically_known_equals.
  3613. """
  3614. from torch._inductor.virtualized import V
  3615. mat_size = mat.get_size()
  3616. scale_size = scale.get_size()
  3617. # Handle transposed matrix
  3618. if transpose:
  3619. mat_size = (mat_size[1], mat_size[0])
  3620. # Compute scale numel symbolically
  3621. scale_numel = functools.reduce(operator.mul, scale_size, 1) if scale_size else 1
  3622. def symbolic_eq(a: Any, b: Any) -> bool:
  3623. """Compare values using symbolic equality when possible."""
  3624. return V.graph.sizevars.statically_known_equals(a, b)
  3625. return _infer_scale_swizzle_impl(
  3626. mat_size=(mat_size[0], mat_size[1]) if len(mat_size) >= 2 else (mat_size[0], 1),
  3627. scale_size=tuple(scale_size),
  3628. scale_numel=scale_numel,
  3629. mat_dtype=mat.dtype,
  3630. scale_dtype=scale.dtype,
  3631. eq_fn=symbolic_eq,
  3632. )