| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
277427842794280428142824283428442854286428742884289429042914292429342944295429642974298429943004301430243034304430543064307430843094310431143124313431443154316431743184319432043214322432343244325432643274328432943304331433243334334433543364337433843394340434143424343434443454346434743484349435043514352435343544355435643574358435943604361436243634364436543664367436843694370437143724373437443754376437743784379438043814382438343844385438643874388438943904391439243934394439543964397439843994400440144024403440444054406440744084409441044114412441344144415441644174418441944204421442244234424442544264427442844294430443144324433443444354436443744384439444044414442444344444445444644474448444944504451445244534454445544564457445844594460446144624463446444654466446744684469447044714472447344744475447644774478447944804481448244834484448544864487448844894490449144924493449444954496449744984499450045014502450345044505450645074508450945104511451245134514451545164517451845194520452145224523452445254526452745284529453045314532453345344535 |
- from __future__ import annotations
- import collections
- import contextlib
- import dataclasses
- import enum
- import functools
- import importlib
- import inspect
- import io
- import itertools
- import logging
- import math
- import operator
- import os
- import platform
- import re
- import shutil
- import statistics
- import sys
- import sysconfig
- import tempfile
- import textwrap
- import time
- import unittest
- from collections.abc import (
- Callable,
- Collection,
- Generator,
- Iterator,
- Mapping,
- MutableMapping,
- MutableSet,
- )
- from datetime import datetime
- from functools import lru_cache
- from io import StringIO
- from typing import (
- Any,
- cast,
- Concatenate,
- Generic,
- Literal,
- NamedTuple,
- Optional,
- Protocol,
- TYPE_CHECKING,
- TypeAlias,
- TypeGuard,
- TypeVar,
- Union,
- )
- from typing_extensions import dataclass_transform, ParamSpec, Self
- from unittest import mock
- import sympy
- import torch
- import torch.utils._pytree as pytree
- from torch._inductor.analysis.device_info import datasheet_tops
- from torch._inductor.runtime.hints import DeviceProperties
- from torch.fx.passes.regional_inductor import _needs_inductor_compile
- from torch.utils._dtype_abbrs import dtype_abbrs
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._pytree import tree_flatten, tree_map_only
# Post-grad pass names excluded from Optimus handling (presumably; how this
# list is consumed is not visible in this part of the module — confirm).
OPTIMUS_EXCLUDE_POST_GRAD = [
    "activation_quantization_aten_pass",
    "inductor_autotune_lookup_table",
]
- from torch.fx.experimental.symbolic_shapes import (
- free_symbols,
- free_unbacked_symbols,
- IterateExprs,
- ShapeEnv,
- )
- if TYPE_CHECKING:
- from collections.abc import Iterable, Sequence, ValuesView
- from pathlib import Path
- from torch import SymBool, SymFloat, SymInt
- from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
- from torch.fx import GraphModule
- from torch.fx.node import Node
- from torch.nn.functional import ScalingType # type: ignore[attr-defined]
- from .codegen.common import WorkspaceArg
- from .codegen.wrapper import PythonWrapperCodegen
- from .dependencies import Dep
- from .graph import GraphLowering
- from .ir import Buffer, ExternKernel, IRNode, Layout, Operation, ReinterpretView
- from .output_code import CompiledFxGraph
- from .scheduler import BaseSchedulerNode, SchedulerBuffer
- GPU_TYPES = ["cuda", "mps", "xpu", "mtia"]
- T = TypeVar("T")
- # defines here before import torch._dynamo is for avoiding circular import
- # when get_gpu_type is imported from dynamo
- @functools.cache
- def get_gpu_type() -> str:
- avail_gpus = [x for x in GPU_TYPES if getattr(torch, x).is_available()]
- assert len(avail_gpus) <= 1
- gpu_type = "cuda" if len(avail_gpus) == 0 else avail_gpus.pop()
- return gpu_type
- from torch._dynamo.device_interface import get_interface_for_device
- from torch._dynamo.utils import detect_fake_mode
- from torch.autograd import DeviceType
- from torch.autograd.profiler_util import EventList
- from torch.fx.passes.graph_transform_observer import GraphTransformObserver
- from torch.fx.passes.shape_prop import ShapeProp
- from torch.utils._sympy.functions import (
- CeilDiv,
- CleanDiv,
- FloorDiv,
- Identity,
- ModularIndexing,
- )
- from torch.utils._sympy.symbol import make_symbol, SymT
- from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges
- from . import config
- from .runtime.runtime_utils import ceildiv as runtime_ceildiv
- _IS_WINDOWS = sys.platform == "win32"
- log = logging.getLogger(__name__)
- _T = TypeVar("_T")
- VarRanges = dict[sympy.Expr, sympy.Expr]
- InputType = Optional[Union[torch.Tensor, int, torch.SymInt]]
- XPU_KERNEL_FORMAT = (
- "spv" if _IS_WINDOWS else os.getenv("TORCHINDUCTOR_XPU_KERNEL_FORMAT", "zebin")
- )
- GPU_KERNEL_BIN_EXTS = {
- "cuda": ".cubin",
- "hip": ".hsaco",
- "xpu": f".{XPU_KERNEL_FORMAT}",
- }
- GPU_ALIGN_BYTES = 16
- ALIGNMENT = 16
- TMA_ALIGNMENT = 16
- TMA_DESCRIPTOR_SIZE = 128
- # PyTorch dtypes with valid CUtensorMapDataType mappings.
- # Ref: triton/backends/nvidia/include/cuda.h (CUtensorMapDataType enum)
- # triton/_internal_testing.py (tma_dtypes test list)
- _TMA_SUPPORTED_DTYPES: OrderedSet[torch.dtype] = OrderedSet(
- [
- torch.uint8,
- torch.int8,
- torch.uint16,
- torch.int16,
- torch.uint32,
- torch.int32,
- torch.int64,
- torch.float16,
- torch.bfloat16,
- torch.float32,
- torch.float64,
- torch.float8_e4m3fn,
- torch.float8_e5m2,
- torch.float8_e4m3fnuz,
- torch.float8_e5m2fnuz,
- ]
- )
# Alignment (in bytes) enforced by _align / align. It must be a power of two
# (checked at import) because rounding is done with a bit mask.
ALIGN_BYTES = 64
assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2"


def _align(nbytes: int) -> int:
    """Round ``nbytes`` up to the nearest multiple of ALIGN_BYTES."""
    # -ALIGN_BYTES has every bit set except the low log2(ALIGN_BYTES) bits.
    mask = -ALIGN_BYTES
    return (nbytes + (ALIGN_BYTES - 1)) & mask


def _is_aligned(v: sympy.Expr) -> bool:
    """
    Return True when ``v`` is statically provable to be a multiple of ALIGN_BYTES.

    A sum or max is aligned when every operand is. Any other expression must
    be an ``align(...)`` call or have ALIGN_BYTES as its gcd with ALIGN_BYTES.
    """
    if not isinstance(v, (sympy.Add, sympy.Max)):
        return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES
    return all(_is_aligned(arg) for arg in v.args)


class align(sympy.Function):
    """
    Symbolically round an expression up to the nearest multiple of ALIGN_BYTES.

    Concrete integers are rounded immediately. Expressions already provably
    aligned are returned unchanged. Anything else stays an unevaluated
    ``align(...)``.
    """

    nargs = (1,)
    is_integer = True

    @classmethod
    def eval(cls, value: sympy.Expr) -> Optional[sympy.Expr]:
        if isinstance(value, (int, sympy.Integer)):
            return _align(int(value))
        # Returning None tells sympy to leave align(value) unevaluated.
        return value if _is_aligned(value) else None
@dataclasses.dataclass(frozen=True)
class GraphPartitionMap:
    """
    Immutable record relating a graph partition's inputs/outputs to the graph's.

    Fields:
        id: unique id of the graph partition.
        input_index_mapping: for each partition input, the index of the matching
            graph input, or None when that partition input is not a graph input.
        output_index_mapping: the same mapping for partition outputs against
            graph outputs.
        constant_names: names of the constants the partition reads or writes.
    """

    id: int
    input_index_mapping: list[Optional[int]]
    output_index_mapping: list[Optional[int]]
    constant_names: list[str]
def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float:
    """
    Benchmark ``fn`` on CUDA and return its mean per-call time in milliseconds.

    Each timed call is bracketed by CUDA events and preceded by zeroing a large
    scratch buffer (intended to flush the L2 cache). The calls run under the
    torch profiler. The mean device time of any profiled ``fused_abs_max_<N>``
    kernels is subtracted from the event-timed mean.

    Args:
        fn: zero-argument callable to benchmark; it is expected to launch CUDA work.
        warmup: warm-up budget in ms, converted to a call count using a
            5-call runtime estimate (at least one call).
        rep: measurement budget in ms, converted the same way.

    Returns:
        Mean per-call time in ms after the ``fused_abs_max`` adjustment.
    """
    fn()
    torch.cuda.synchronize()
    # int(256e6 // 4) fp16 elements; zeroed before every call to evict caches.
    cache = torch.empty(int(256e6 // 4), dtype=torch.float16, device="cuda")

    # Estimate the runtime of the function (including the cache flush).
    est_start = torch.cuda.Event(enable_timing=True)
    est_end = torch.cuda.Event(enable_timing=True)
    est_start.record()
    for _ in range(5):
        cache.zero_()
        fn()
    est_end.record()
    torch.cuda.synchronize()
    estimate_ms = est_start.elapsed_time(est_end) / 5

    # Compute number of warmup and repeat iterations from the ms budgets.
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()

    # One start/end event pair per timed repetition.
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CUDA,
        ]
    ) as p:
        torch.cuda.synchronize()
        for i in range(n_repeat):
            cache.zero_()
            start_events[i].record()
            with torch.cuda.nvtx.range("RunCudaModule"):
                fn()
            end_events[i].record()
        torch.cuda.synchronize()
        times = torch.tensor(
            [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
        )

    res = torch.mean(times).item()
    # Building the profiler table is expensive; only do it when it will be logged.
    if log.isEnabledFor(logging.DEBUG):
        log.debug("raw events")
        log.debug(
            p.key_averages().table(sort_by="self_device_time_total", row_limit=-1)
        )

    filtered_events = EventList(
        [
            event
            for event in p.events()
            if (
                event.device_type == DeviceType.CUDA
                and re.match(r"fused_abs_max_\d", event.name) is not None
            )
        ]
    )
    if filtered_events:
        # device_time_total is divided by 1000 to match res, which is in ms.
        res -= (
            statistics.mean(event.device_time_total for event in filtered_events)
            / 1000.0
        )

    log.debug("profiling results: %s ms", res)
    return res
def do_bench_using_profiling(
    fn: Callable[[], Any],
    warmup: int = 25,
    rep: int = 100,
    is_vetted_benchmarking: bool = False,
) -> float:
    """
    Profiler-based benchmark of ``fn``: ``_do_bench_using_profiling`` wrapped with
    ``may_distort_benchmarking_result``.

    Arguments are forwarded positionally; returns whatever the wrapped
    benchmark returns (a time in ms).
    """
    # The wrapper is applied at call time rather than as a decorator because
    # importing torch._inductor.runtime.benchmarking at module scope initializes
    # CUDA (it calls torch.cuda.is_available() at import time). That breaks vllm
    # when it spawns child processes, see:
    # https://gist.github.com/shunting314/c194e147bf981e58df095c14874dd65a
    #
    # Moving do_bench_using_profiling into torch._inductor.runtime.benchmarking
    # would also work, but it has too many call sites inside and outside
    # pytorch to change easily.
    from torch._inductor.runtime.benchmarking import may_distort_benchmarking_result

    wrapped_bench = may_distort_benchmarking_result(_do_bench_using_profiling)
    return wrapped_bench(fn, warmup, rep, is_vetted_benchmarking)
def _do_bench_using_profiling(
    fn: Callable[[], Any],
    warmup: int = 25,
    rep: int = 100,
    is_vetted_benchmarking: bool = False,
) -> float:
    """
    Benchmark ``fn`` on the available GPU backend using profiler-reported device time.

    The procedure is:
      1. Run ``fn`` once.
      2. Estimate its runtime from 5 calls. Each call is preceded by zeroing a
         256e6-byte int32 scratch buffer, intended to flush the L2 cache.
      3. Turn the ``warmup`` / ``rep`` ms budgets into iteration counts.
      4. Profile ``n_repeat`` iterations of (zero buffer, ``fn()``).

    The profiled device events (excluding "Context Sync") are split into
    ``n_repeat`` equal, contiguous groups. The first event of each group is
    dropped; presumably that is the buffer-zeroing kernel, which assumes events
    are ordered by iteration. The remaining device time is averaged per
    repetition.

    Args:
        fn: zero-argument callable to benchmark.
        warmup: warm-up budget in ms (at least one call).
        rep: measurement budget in ms (at least one call).
        is_vetted_benchmarking: when False, ``may_ban_benchmarking()`` is called
            first.

    Returns:
        Mean device time per repetition, in ms.

    Raises:
        RuntimeError: if the event count is not a multiple of the repeat count.
    """
    if not is_vetted_benchmarking:
        from torch._inductor.runtime.benchmarking import may_ban_benchmarking

        may_ban_benchmarking()

    device_type = get_gpu_type()
    device_type_upper = device_type.upper()
    device_interface = get_interface_for_device(device_type)

    fn()
    device_interface.synchronize()
    cache = torch.empty(int(256e6 // 4), dtype=torch.int, device=device_type)

    # Estimate the runtime of the function
    start_event = device_interface.Event(enable_timing=True)
    end_event = device_interface.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    device_interface.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Warm-up
    for _ in range(n_warmup):
        fn()
    device_interface.synchronize()

    with torch.profiler.profile(
        activities=[
            getattr(torch.profiler.ProfilerActivity, device_type_upper),
        ]
    ) as p:
        # Benchmark
        for _ in range(n_repeat):
            # we clear the L2 cache before each run
            cache.zero_()
            fn()
        # Record clocks
        device_interface.synchronize()

    # Building profiler tables is expensive; only do it when they will be logged.
    debug_enabled = log.isEnabledFor(logging.DEBUG)
    if debug_enabled:
        log.debug("raw events")
        log.debug(
            p.key_averages().table(sort_by="self_device_time_total", row_limit=-1)
        )

    target_device = getattr(DeviceType, device_type_upper)
    filtered_events = EventList(
        [
            event
            for event in p.events()
            if event.device_type == target_device and event.name != "Context Sync"
        ]
    )
    if len(filtered_events) % n_repeat != 0:
        # The message was previously passed logging-style (template + args), so
        # it was never formatted; build the final string here instead.
        raise RuntimeError(
            "Failed to divide all profiling events into #repeat groups. "
            f"#{device_type} events: {len(filtered_events)}, #repeats: {n_repeat}"
        )
    # Exact: divisibility was checked just above.
    num_event_per_group = len(filtered_events) // n_repeat
    # Drop the first event of every group (see docstring).
    actual_events = EventList(
        [
            event
            for i, event in enumerate(filtered_events)
            if i % num_event_per_group != 0
        ]
    )
    actual_events._build_tree()
    actual_events = actual_events.key_averages()

    if debug_enabled:
        log.debug("profiling time breakdown")
        log.debug(actual_events.table(row_limit=-1))

    # device_time_total is divided by 1000 to report ms, then averaged per repeat.
    res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat
    log.debug("profiling results: %s ms", res)
    return res
@functools.cache
def has_torchvision_roi_align() -> bool:
    """
    Return True if torchvision is importable and ``torch.ops.torchvision.roi_align``
    is registered.

    Returns False when torchvision cannot be imported, or when the dispatcher
    reports that ``torchvision::nms`` does not exist. Any other RuntimeError is
    re-raised unchanged. The result is cached.
    """
    try:
        from torchvision.ops import roi_align  # noqa: F401

        # Called only for its error: it raises RuntimeError when torchvision's
        # operators are not registered. The returned value is not used.
        torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta")
        return roi_align is not None and hasattr(
            getattr(torch.ops, "torchvision", None), "roi_align"
        )
    except ImportError:
        return False
    except RuntimeError as e:
        # Previously an `assert`: under `python -O` every RuntimeError became
        # False, and otherwise the real error was masked as AssertionError.
        if "torchvision::nms does not exist" not in str(e):
            raise
        return False
def decode_device(device: Union[Optional[torch.device], str]) -> torch.device:
    """
    Normalize ``device`` into a ``torch.device``.

    - ``None`` resolves to the device of a freshly created tensor (the default
      device).
    - A string is parsed with ``torch.device``.
    - A non-"cpu"/"meta" device without an index is given the backend's current
      device index, looked up through the dynamo device interface.
    """
    if device is None:
        return torch.tensor(0.0).device  # default device
    resolved = torch.device(device) if isinstance(device, str) else device
    needs_index = resolved.type not in ("cpu", "meta") and resolved.index is None
    if not needs_index:
        return resolved
    interface = get_interface_for_device(resolved.type)
    return torch.device(resolved.type, index=interface.Worker.current_device())
def sympy_product(it: Iterable[sympy.Expr]) -> sympy.Expr:
    """Multiply the items of ``it`` left to right, starting from ``sympy.S.One``."""
    return math.prod(it, start=sympy.S.One)
def sympy_dot(seq1: Sequence[sympy.Expr], seq2: Sequence[sympy.Expr]) -> sympy.Expr:
    """Return the expanded dot product ``sum(a * b)`` of two equal-length sequences."""
    assert len(seq1) == len(seq2)
    pairwise_products = (lhs * rhs for lhs, rhs in zip(seq1, seq2))
    return sympy.expand(sum(pairwise_products))
def unique(it: Iterable[_T]) -> ValuesView[_T]:
    """Deduplicate ``it`` by object identity (``id``), keeping first-seen order."""
    by_identity: dict[int, _T] = {}
    for obj in it:
        by_identity[id(obj)] = obj
    return by_identity.values()
- def ceildiv(
- number: Union[int, sympy.Expr], denom: Union[int, sympy.Expr]
- ) -> Union[int, sympy.Expr]:
- if isinstance(number, sympy.Expr) or isinstance(denom, sympy.Expr):
- return CeilDiv(sympy.sympify(number), sympy.sympify(denom))
- # TODO: There is a bug in a call to this function, to repro:
- # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
- # --amp --only YituTechConvBert --dynamic-shapes
- assert isinstance(number, int) and isinstance(denom, int), (
- f"{number}: {type(number)}, {denom}: {type(denom)}"
- )
- return runtime_ceildiv(number, denom)
- def _type_of(key: Optional[torch.dtype]) -> str:
- # Use the function here to get rid of dependencies on the Triton during the codegen.
- # Refer to Triton implementation here:
- # https://github.com/triton-lang/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
- # `None` is nullptr. Implicitly convert to *i8.
- if key is None:
- return "*i8"
- dtype_str = str(key).split(".")[-1]
- tys = {
- "bool": "i1",
- "float8e4nv": "fp8e4nv",
- "float8e5": "fp8e5",
- "float8e4b15": "fp8e4b15",
- "float8e4b15x4": "fp8e4b15x4",
- "float8_e4m3fn": "fp8e4nv",
- "float8_e5m2": "fp8e5",
- # TODO: remove when support is added in triton
- # https://github.com/triton-lang/triton/issues/6054
- "float8_e8m0fnu": "u8",
- "float4_e2m1fn_x2": "u8",
- "float16": "fp16",
- "bfloat16": "bf16",
- "float32": "fp32",
- "float64": "fp64",
- "int8": "i8",
- "int16": "i16",
- "int32": "i32",
- "int64": "i64",
- "uint8": "u8",
- "uint16": "u16",
- "uint32": "u32",
- "uint64": "u64",
- }
- # reinterpret can create triton type
- tys.update({v: v for v in list(tys.values())})
- return key if isinstance(key, str) else f"*{tys[dtype_str]}"
def convert_shape_to_inductor(
    lst: Iterable[Union[int, torch.SymInt]],
) -> list[sympy.Expr]:
    """
    Convert a shape or stride sequence into a list of sympy expressions.

    Plain ints become ``sympy.Integer``. Symbolic sizes (SymInt) are converted
    through ``sympy.sympify`` into their sympy expressions.
    """
    return list(map(sympy.sympify, lst))
def convert_symint_to_expr(val: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
    """
    Return the sympy expression behind a ``torch.SymInt``; return anything else as-is.

    Unlike ``sympy.sympify``, plain ints are not turned into ``sympy.Integer``.
    """
    return val.node.expr if isinstance(val, torch.SymInt) else val
def convert_to_symint(i: Union[int, sympy.Expr]) -> Union[int, torch.SymInt]:
    """
    Convert one Inductor size expression back into an ``int`` or ``SymInt``.

    This is the single-value counterpart of ``convert_shape_to_symint``:

    - Python ints pass through unchanged.
    - ``sympy.Integer`` becomes a Python ``int``.
    - Any other expression becomes a new SymInt node on the current graph's
      shape environment (created with no hint).
    """
    from .virtualized import V

    if isinstance(i, int):
        return i
    if isinstance(i, sympy.Integer):
        return int(i)
    return V.graph.sizevars.shape_env.create_symintnode(i, hint=None)
def convert_shape_to_symint(
    lst: Iterable[Union[int, sympy.Expr]],
) -> list[Union[int, torch.SymInt]]:
    """
    Convert an Inductor shape (ints and sympy expressions) into ints/SymInts.

    Each entry is converted independently via ``convert_to_symint``. A fully
    static shape therefore comes back as plain ints.
    """
    return list(map(convert_to_symint, lst))
def is_view(op: torch._ops.OpOverload) -> bool:
    """
    Return True if any argument in ``op``'s schema carries an alias annotation.

    This matches arguments written as ``Tensor(a)`` in the schema. Any aliasing
    annotation counts, so presumably in-place (``Tensor(a!)``) ops also
    qualify.
    """
    for argument in op._schema.arguments:
        if argument.alias_info is not None:
            return True
    return False
def is_pointwise_use(
    use: Node,
    is_pointwise_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False,
) -> bool:
    """
    Decide whether the FX node ``use`` is a pointwise consumer.

    A node qualifies when it is a ``call_function`` whose target is an
    OpOverload tagged ``torch.Tag.pointwise``, or for which ``is_pointwise_fn``
    returns True. ``operator.getitem`` and view ops (see ``is_view``) are looked
    through: they qualify only if all of *their* users qualify, recursively.
    Any other node kind or target returns False.
    """
    if use.op != "call_function":
        return False

    target = use.target
    if target is operator.getitem:
        return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
    if not isinstance(target, torch._ops.OpOverload):
        return False

    op = cast(torch._ops.OpOverload, target)
    if is_view(op):
        # Views don't compute anything; judge them by what consumes them.
        return all(is_pointwise_use(u, is_pointwise_fn) for u in use.users)
    return torch.Tag.pointwise in op.tags or is_pointwise_fn(op)
- def gen_gm_and_inputs(
- target: Any, args: list[Any], kwargs: dict[str, Any]
- ) -> tuple[GraphModule, list[torch.Tensor]]:
- g = torch.fx.Graph()
- graph_args: list[torch.Tensor] = []
- def add_tensor_arg(arg: torch.Tensor) -> Node:
- graph_args.append(arg)
- return g.placeholder(f"arg{len(graph_args)}")
- node = g.call_function(
- target, *tree_map_only(torch.Tensor, add_tensor_arg, (args, kwargs))
- )
- if (
- len(target._schema.returns) == 1
- and str(target._schema.returns[0].type) == "Tensor"
- ):
- node = (node,) # type: ignore[assignment]
- g.output(node)
- gm = torch.fx.GraphModule({}, g)
- return gm, graph_args
- def synchronize(device: str = "cuda") -> None:
- if device == "cpu":
- return
- device_interface = get_interface_for_device(device)
- if device_interface.is_available():
- device_interface.synchronize()
def timed(
    model: Callable[..., Any],
    example_inputs: Sequence[Any],
    times: int = 1,
    device: str = "cuda",
) -> float:
    """
    Return the wall-clock seconds taken by ``times`` calls of ``model``.

    The device is synchronized before the clock starts and after the last
    call. The RNG is seeded with 1337 first so runs are repeatable.
    """
    synchronize(device)
    torch.manual_seed(1337)

    start = time.perf_counter()
    for _ in range(times):
        result = model(*example_inputs)
    synchronize(device)
    elapsed = time.perf_counter() - start

    # Referencing `result` here keeps the last output alive until the clock has
    # stopped, so its deallocation is not part of the measurement.
    assert result is not None  # type: ignore[possibly-undefined]
    return elapsed
def print_performance(
    model: Callable[..., Any],
    example_inputs: Sequence[Any] = (),
    times: int = 10,
    repeat: int = 10,
    baseline: float = 1.0,
    device: str = "cuda",
) -> float:
    """
    Benchmark ``model`` and report its median per-call time.

    The model is timed ``repeat`` times, each run covering ``times`` calls. The
    median per-call time divided by ``baseline`` is printed with 6 decimals,
    and the undivided median per-call time (in seconds) is returned.
    """
    samples = [timed(model, example_inputs, times, device) for _ in range(repeat)]
    per_call = torch.median(torch.tensor(samples)) / times
    print(f"{per_call / baseline:.6f}")
    return per_call.item()
def precompute_method(obj: Any, method: str) -> None:
    """
    Evaluate ``obj.<method>()`` once and pin the result.

    The attribute is replaced on ``obj`` with a zero-argument callable that
    returns the precomputed value.
    """
    value = getattr(obj, method)()

    def precomputed() -> Any:
        return value

    setattr(obj, method, precomputed)
def precompute_methods(obj: Any, methods: list[str]) -> None:
    """Apply ``precompute_method`` to each name in ``methods``, in order."""
    for name in methods:
        precompute_method(obj, name)
def cmp(a: int, b: int) -> int:
    """Three-way comparison: 1 if ``a > b``, -1 if ``a < b``, otherwise 0."""
    if a > b:
        return 1
    if a < b:
        return -1
    return 0
def pad_listlike(x: Union[int, Sequence[int]], size: int) -> Sequence[int]:
    """
    Broadcast a scalar or a single-element sequence to length ``size``.

    - An ``int`` becomes a list of ``size`` copies.
    - A one-element sequence is repeated ``size`` times, keeping its container
      type (list stays list, tuple stays tuple).
    - Any other sequence is returned as-is.
    """
    if isinstance(x, int):
        return [x] * size
    if len(x) != 1:
        return x
    return type(x)([x[0]]) * size  # type: ignore[call-arg, operator, return-value]
# Used to ensure that iterating over a set is deterministic
def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
    """
    Return the items of ``x`` sorted by name.

    Strings sort by their own value. Scheduler nodes sort by ``get_name()``.
    Any other element type fails the ``BaseSchedulerNode`` assertion.
    """

    def name_of(item: _T) -> str:
        if isinstance(item, str):
            return item
        from .scheduler import BaseSchedulerNode

        assert isinstance(item, BaseSchedulerNode)
        return item.get_name()

    # sorted() of an empty tuple is already [], so no special case is needed.
    return sorted(x, key=name_of)
# Type variables describing a method's parameters (P) and return value (RV).
P = ParamSpec("P")
RV = TypeVar("RV", covariant=True)
# A callable taking `self` (Any) followed by the P parameters, returning RV.
FN_TYPE = Callable[Concatenate[Any, P], RV]


class CachedMethod(Protocol, Generic[P, RV]):
    """
    Static type of a method wrapped by ``cache_on_self``.

    It is callable like the original method and also exposes ``clear_cache``.
    The caching decorators in this file call ``clear_cache`` with the instance
    whose cached value should be dropped.
    """

    @staticmethod
    def clear_cache(cache: Any) -> None: ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ...
# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]:
    """
    Memoize a method that takes only ``self``, storing the result on the instance.

    The first call stores ``fn(self)`` in the instance attribute
    ``__<fn name>_cache``. Later calls return that attribute without calling
    ``fn``. The value is written with ``object.__setattr__``, which bypasses any
    ``__setattr__`` override on the class (presumably so frozen dataclasses
    work too; confirm).

    ``<wrapped>.clear_cache(instance)`` deletes the stored value, if present.
    """
    name = fn.__name__
    key = f"__{name}_cache"

    # wrapper is likely on the hot path, compile a specialized version of it
    # (the attribute name is baked into the generated source as a literal).
    ctx = {"fn": fn}
    exec(
        f"""\
        def {name}_cache_on_self(self):
            try:
                return self.{key}
            except AttributeError:
                pass
            rv = fn(self)
            object.__setattr__(self, "{key}", rv)
            return rv
        """.lstrip(),
        ctx,
    )
    wrapper = functools.wraps(fn)(ctx[f"{name}_cache_on_self"])

    def clear_cache(self: Any) -> None:
        # Drop this instance's cached value so the next call recomputes it.
        if hasattr(self, key):
            delattr(self, key)

    wrapper.clear_cache = clear_cache  # type: ignore[attr-defined]
    return wrapper  # type: ignore[return-value]
def cache_property_on_self(
    fn: Callable[Concatenate[Any, P], RV],
) -> CachedMethod[P, RV]:
    """
    Property-flavoured alias of ``cache_on_self``.

    It caches identically; the separate name exists only so property-style
    getters get their own type signature.
    """
    cached = cache_on_self(fn)
    return cached
def cache_on_self_and_args(
    class_name: str,
) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
    """
    Decorator factory that memoizes a method per instance, keyed on its arguments.

    Results live in a dict stored in the instance attribute
    ``__<class_name>_<fn name>_cache``. The dict key is
    ``(args, tuple(sorted(kwargs.items())))``, so every argument must be
    hashable and keyword names must be sortable. ``class_name`` is part of the
    attribute name so that an override and the ``super()`` implementation it
    calls keep separate caches.

    ``<wrapped>.clear_cache(instance)`` discards that instance's cache.
    """

    # include both class_name and fn_name in the key to support `super().fn(self, **args, **kwargs)` calls.
    def wrapper(
        fn: FN_TYPE[P, RV],
    ) -> FN_TYPE[P, RV]:
        key = f"__{class_name}_{fn.__name__}_cache"

        # wrapper is likely on the hot path, compile a specialized version of it.
        # The generated source carries no annotations on purpose. `ctx` only
        # provides `fn`, so names such as Any/P/RV would raise NameError at
        # definition time unless the defining module happens to postpone
        # annotation evaluation. functools.wraps copies fn's annotations onto
        # the result anyway.
        ctx = {"fn": fn}
        exec(
            f"""\
            def inner(self, *args, **kwargs):
                args_kwargs = (args, tuple(sorted(kwargs.items())))
                try:
                    cache = self.{key}
                except AttributeError:
                    cache = {{}}
                    object.__setattr__(self, "{key}", cache)
                try:
                    return cache[args_kwargs]
                except KeyError:
                    pass
                rv = fn(self, *args, **kwargs)
                cache[args_kwargs] = rv
                return rv
            """.lstrip(),
            ctx,
        )
        inner = functools.wraps(fn)(ctx["inner"])

        def clear_cache(self: Any) -> None:
            # Drop every cached result for this instance.
            if hasattr(self, key):
                delattr(self, key)

        inner.clear_cache = clear_cache  # type: ignore[attr-defined]
        return inner

    return wrapper
- def aggregate_origins(
- node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
- ) -> OrderedSet[Node]:
- from . import ir
- if isinstance(node_schedule, list):
- return functools.reduce(
- operator.or_,
- [
- # pyrefly: ignore [missing-attribute]
- node.node.origins
- for node in node_schedule
- if hasattr(node, "node") and node.node
- ],
- OrderedSet(),
- )
- elif isinstance(node_schedule, ir.ExternKernel):
- return node_schedule.origins
- else:
- return OrderedSet()
def get_fused_kernel_name(
    node_schedule: Sequence[BaseSchedulerNode],
    descriptive_names: Literal[True, "torch", "original_aten", "inductor_node"],
) -> str:
    """
    Build a ``fused_<part>_<part>...`` kernel name from ``node_schedule``'s origins.

    Only ``call_function`` origins are used. ``descriptive_names`` selects the
    source of the name parts:

    * ``"original_aten"``: the op in ``meta["original_aten"]``, i.e. before
      decompositions. This is the overload-packet name for an OpOverload,
      ``name()`` for a higher-order op, and ``""`` for anything else.
      Deduplicated and sorted.
    * ``"torch"``: the top-level torch op from ``meta["source_fn_stack"][-1]``,
      or from ``meta["fwd_source_fn_stack"][-1]`` (backward nodes) with the
      suffix ``"backward"``. Deduplicated and sorted.
    * ``"inductor_node"``: the FX node names, in origin order.

    Raises:
        NotImplementedError: for any other value, including ``True``.
    """
    call_origins = [
        origin
        for origin in aggregate_origins(node_schedule)
        if origin.op == "call_function"
    ]

    if descriptive_names == "original_aten":
        aten_names: list[str] = []
        for origin in call_origins:
            aten_op = origin.meta.get("original_aten")
            if aten_op is None:
                continue
            if isinstance(aten_op, torch._ops.OpOverload):
                aten_names.append(aten_op._overloadpacket.__name__)
            elif isinstance(aten_op, torch._ops.HigherOrderOperator):
                aten_names.append(str(aten_op.name()))
            else:
                aten_names.append("")
        parts = sorted(OrderedSet(aten_names))
    elif descriptive_names == "torch":
        torch_names: list[str] = []
        for origin in call_origins:
            meta = origin.meta
            if "source_fn_stack" in meta:
                source_fn, suffix = meta["source_fn_stack"][-1], ""
            elif "fwd_source_fn_stack" in meta:
                # backward nodes have "fwd_source_fn_stack" instead
                source_fn, suffix = meta["fwd_source_fn_stack"][-1], "backward"
            else:
                continue
            if not source_fn:
                continue
            fn_ref = source_fn[1]
            base = fn_ref if isinstance(fn_ref, str) else fn_ref.__name__
            torch_names.append(base + suffix)
        parts = sorted(OrderedSet(torch_names))
    elif descriptive_names == "inductor_node":
        parts = [origin.name for origin in call_origins]
    else:
        raise NotImplementedError

    return "_".join(["fused", *parts])
def get_kernel_metadata(
    node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
    wrapper: PythonWrapperCodegen,
) -> tuple[str, str]:
    """
    Build the comment lines that describe where a generated kernel came from.

    Args:
        node_schedule (Union[Sequence[BaseSchedulerNode], ExternKernel]):
            Either a sequence of BaseSchedulerNode objects or an ExternKernel instance.
        wrapper (PythonWrapperCodegen):
            Supplies ``wrapper.comment``, the comment prefix used for every line.

    Returns:
        tuple[str, str]:
            - A one-line summary listing the source nodes and original ATen ops.
            - Multi-line detail: a source-node -> ATen-node mapping, and, when all
              origins come from a single FX graph, a "Graph fragment" listing of
              the kernel's inputs, nodes and outputs.
    """
    all_origins = aggregate_origins(node_schedule)
    inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"]

    # original source-node name -> names of inductor nodes derived from it
    from_node_dict = collections.defaultdict(list)
    # original ATen op name -> names of inductor nodes produced from it
    original_aten_dict = collections.defaultdict(list)

    # Attempt to sort `inductor_nodes` topologically. Note that the case
    # where `inductor_nodes` contains nodes from multiple graph instances
    # is not supported. An example of this is conditional statements.
    single_graph = None
    if inductor_nodes:
        unique_graphs = OrderedSet(n.graph for n in inductor_nodes)
        if len(unique_graphs) == 1:
            single_graph = inductor_nodes[0].graph
            # create a map of idx -> node and cache it
            # NOTE(review): the map is cached on the graph object and never
            # refreshed; nodes added to the graph afterwards would be missing
            # from it (KeyError in the sort below). Confirm graphs are not
            # mutated once this runs.
            if not hasattr(single_graph, "_inductor_kernel_metadata_node_to_idx_map"):
                node_to_idx_map = {n: idx for idx, n in enumerate(single_graph.nodes)}
                single_graph._inductor_kernel_metadata_node_to_idx_map = node_to_idx_map  # type: ignore[attr-defined]
            inductor_nodes.sort(
                key=lambda n: single_graph._inductor_kernel_metadata_node_to_idx_map[n]  # type: ignore[attr-defined]
            )

    for node in inductor_nodes:
        if "original_aten" in node.meta and node.meta["original_aten"] is not None:
            original_aten = node.meta["original_aten"]
            key = None
            if isinstance(original_aten, torch._ops.OpOverload):
                key = str(original_aten._overloadpacket)
            elif isinstance(original_aten, torch._ops.HigherOrderOperator):
                key = str(original_aten.name())
            if key:
                original_aten_dict[key].append(node.name)
        if "from_node" in node.meta:
            key = node.meta["from_node"][0].name
            from_node_dict[key].append(node.name)
        elif node.meta.get("partitioner_tag") == "is_backward":
            # backward nodes currently don't have a "from node"
            from_node_dict[node.name].append(node.name)
    sort_str = "Topologically Sorted" if single_graph is not None else "Unsorted"
    metadata = (
        f"{wrapper.comment} {sort_str} Source Nodes: [{', '.join(from_node_dict.keys())}], "
        f"Original ATen: [{', '.join(original_aten_dict.keys())}]"
    )

    # trace back to original node here
    detailed_metadata = [f"{wrapper.comment} Source node to ATen node mapping:"]
    for original_node, nodes in sorted(from_node_dict.items()):
        detailed_metadata.append(
            f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}"
        )

    # print the aot_autograd graph fragment
    if single_graph is not None:
        from . import ir

        detailed_metadata.append(f"{wrapper.comment} Graph fragment:")
        # buffer names already emitted as placeholders (deduplicates inputs)
        all_reads: OrderedSet[str] = OrderedSet()
        # "%name" entries for the fragment's return line
        all_writes: list[str] = []
        if not isinstance(node_schedule, ir.ExternKernel):
            from .virtualized import V

            def get_buffer_info(
                buffer: Union[ir.TensorBox, ir.Buffer, ir.TorchBindObject], rw_name: str
            ) -> tuple[str, ir.Layout | None]:
                # Display name: the origin FX node's name when known, otherwise
                # the read/write buffer name. Layout is None when the buffer
                # raises NotImplementedError from get_layout().
                if isinstance(buffer, ir.TensorBox) and isinstance(
                    buffer.data, ir.StorageBox
                ):
                    origin_node = buffer.data.data.origin_node
                else:
                    origin_node = buffer.origin_node
                if origin_node is None:
                    # use the read/write name if no origin node is found
                    name = rw_name
                else:
                    name = origin_node.name
                try:
                    layout = buffer.get_layout()
                except NotImplementedError:
                    layout = None
                return name, layout

            def stringify_shape(shape: Iterable[int]) -> str:
                return f"[{', '.join([str(x) for x in shape])}]"

            def stringfy_layout(layout: ir.Layout | None) -> str:
                # Quoted "<dtype abbr><size><stride><device>" annotation,
                # or "" when the layout is unknown.
                if layout is None:
                    return ""
                shape_annotation = f"{stringify_shape(layout.size)}"
                stride_annotation = f"{stringify_shape(layout.stride)}"
                device_annotation = f"{layout.device}"

                return (
                    f'"{dtype_abbrs[layout.dtype]}{shape_annotation}'
                    f'{stride_annotation}{device_annotation}"'
                )

            for n in node_schedule:
                if not hasattr(n, "read_writes") or n.read_writes is None:
                    continue
                if hasattr(n.read_writes, "reads") and n.read_writes.reads is not None:
                    for r in n.read_writes.reads:
                        # Skip inputs that were already emitted
                        if r.name in all_reads:
                            continue
                        all_reads.add(r.name)
                        buffer = V.graph.try_get_buffer(r.name)
                        if buffer is None:
                            continue
                        input_name, layout = get_buffer_info(buffer, r.name)
                        detailed_metadata.append(
                            f"{wrapper.comment} %{input_name} : Tensor "
                            f"{stringfy_layout(layout)} = PlaceHolder[target={input_name}]"
                        )

                if (
                    hasattr(n.read_writes, "writes")
                    and n.read_writes.writes is not None
                ):
                    for w in n.read_writes.writes:
                        buffer = V.graph.try_get_buffer(w.name)
                        if buffer is None:
                            continue
                        output_name, _ = get_buffer_info(buffer, w.name)
                        all_writes.append("%" + output_name)

        for node in inductor_nodes:
            detailed_metadata.append(
                f"{wrapper.comment} {node.format_node(include_tensor_metadata=True)}"
            )

        detailed_metadata.append(f"{wrapper.comment} return {','.join(all_writes)}")

    return metadata, "\n".join(detailed_metadata)
- def dominated_nodes(
- initial_queue: Iterable[torch.fx.Node],
- skip_filter: Optional[Callable[[Any], bool]] = None,
- ) -> OrderedSet[torch.fx.Node]:
- """Returns the set of nodes whose values depend on those within initial_queue"""
- initial_queue = list(initial_queue)
- dominated_set = OrderedSet(initial_queue)
- while initial_queue:
- node = initial_queue.pop()
- for user in node.users:
- if skip_filter and skip_filter(user):
- continue
- if user not in dominated_set:
- dominated_set.add(user)
- initial_queue.append(user)
- return dominated_set
def gather_origins(
    args: Sequence[IRNode], kwargs: dict[str, IRNode]
) -> OrderedSet[torch.fx.Node]:
    """
    Union the ``origins`` of every unrealized IR node in ``args`` / ``kwargs``.

    Containers such as ``torch.cat([t1, t2])`` arguments are flattened first.
    A node counts as unrealized when, after peeling TensorBox/StorageBox
    wrappers, it is an ``ir.IRNode`` but not a ComputedBuffer, InputsKernel,
    InputBuffer or TemplateBuffer. Args' origins come before kwargs' origins.
    """
    from . import ir

    realized_types = (
        ir.ComputedBuffer,
        ir.InputsKernel,
        ir.InputBuffer,
        ir.TemplateBuffer,
    )

    def is_unrealized_node(n: IRNode) -> bool:
        while isinstance(n, (ir.TensorBox, ir.StorageBox)):
            n = n.data
        return isinstance(n, ir.IRNode) and not isinstance(n, realized_types)

    flat_kwargs, _ = tree_flatten(kwargs)
    from_kwargs = [v.origins for v in flat_kwargs if is_unrealized_node(v)]
    flat_args, _ = tree_flatten(args)
    from_args = [v.origins for v in flat_args if is_unrealized_node(v)]
    return OrderedSet(itertools.chain.from_iterable(from_args + from_kwargs))
- def sympy_str(expr: sympy.Expr) -> str:
- """
- Normal sympy str is very slow, this is a lot faster. The result are
- somewhat worse, as it doesn't do as much simplification. So don't
- use this for final codegen.
- """
- def is_neg_lead(expr: sympy.Expr) -> bool:
- return (
- isinstance(expr, sympy.Mul) and len(expr.args) == 2 and expr.args[0] == -1
- )
- def sympy_str_add(expr: sympy.Expr) -> str:
- if isinstance(expr, sympy.Add):
- # Special case 'a - b'. Note that 'a - b - c' will still appear as
- # 'a + -1 * b + -1 * c'.
- if len(expr.args) == 2 and is_neg_lead(expr.args[1]):
- return f"{sympy_str_mul(expr.args[0])} - {sympy_str_mul(expr.args[1].args[1])}"
- else:
- return " + ".join(map(sympy_str_mul, expr.args))
- else:
- return sympy_str_mul(expr)
- def sympy_str_mul(expr: sympy.Expr) -> str:
- if isinstance(expr, sympy.Mul):
- if is_neg_lead(expr):
- # Special case '-a'. Note that 'a * -b' will still appear as
- # '-1 * a * b'.
- return f"-{sympy_str_atom(expr.args[1])}"
- else:
- return " * ".join(map(sympy_str_atom, expr.args))
- else:
- return sympy_str_atom(expr)
- def sympy_str_atom(expr: sympy.Expr) -> str:
- if isinstance(expr, sympy.Symbol):
- return expr.name
- elif isinstance(expr, (sympy.Add, sympy.Mul)):
- return f"({sympy_str_add(expr)})"
- elif isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv, Identity)):
- return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})"
- else:
- return str(expr)
- return sympy_str_add(expr)
def get_bounds_index_expr(index: sympy.Expr) -> ValueRanges[Any]:
    """
    Return value-range bounds for an index expression.

    Bounds are computed with ``bound_sympy`` only when both of these hold:
    ``config.compute_all_bounds`` is enabled, and the interpreter's current FX
    node exists and is not an ``index_expr`` node (presumably those get bounds
    through another path; confirm). Otherwise ``ValueRanges.unknown()`` is
    returned.
    """
    from .virtualized import V

    if not config.compute_all_bounds:
        return ValueRanges.unknown()
    fx_node = getattr(V.interpreter, "current_node", None)
    if fx_node and fx_node.target != "index_expr":
        return bound_sympy(index)
    return ValueRanges.unknown()
def prefix_is_reduction(prefix: str) -> bool:
    """
    Return True if ``prefix`` names a reduction range, i.e. it starts with ``"r"``.

    An empty prefix is not a reduction. It previously raised ``IndexError``
    via ``prefix[0]``.
    """
    return prefix.startswith("r")
- def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol:
- """
- Used to generate an integer-nonnegative symbol.
- """
- # This should never be used for creating shape/stride symbols, as those
- # should all be allocated before Inductor.
- assert prefix != SymT.SIZE
- # NOTE: shape symbols are positive (> 0), but index variables are only
- # non-negative (>= 0).
- return make_symbol(prefix, idx, integer=True, nonnegative=True)
def generate_assert(check: bool) -> bool:
    """
    Decide whether to emit an indirect-indexing bounds assert.

    An assert is wanted when ``check`` is set or ``config.debug_index_asserts``
    forces it. It is only emitted if ``config.assert_indirect_indexing`` is
    enabled.
    """
    wanted = check or config.debug_index_asserts
    return wanted and config.assert_indirect_indexing
def sympy_index_symbol(name: str) -> sympy.Symbol:
    """Create an integer, nonnegative sympy symbol named ``name`` for use as an index."""
    # Names starting with "s" are reserved for shape/stride symbols, which are
    # all allocated before Inductor and must not be created here.
    assert name[0] != "s"
    # Index variables may be 0, unlike shape symbols which are strictly > 0.
    return sympy.Symbol(name, nonnegative=True, integer=True)
def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.Expr:
    """
    Substitute sub-expressions of ``expr`` according to ``replacements``.

    A replacement given as a ``str`` is turned into a sympy Symbol of that
    name. The new symbol inherits the ``is_integer`` / ``is_nonnegative``
    assumptions of the expression it replaces. Non-string replacements are
    used as-is.
    """

    def as_symbol(
        old: sympy.Expr, new: Union[sympy.Expr, str]
    ) -> sympy.Symbol:
        assert isinstance(old, sympy.Expr)
        if not isinstance(new, str):
            return new
        return sympy.Symbol(
            new,
            integer=old.is_integer,  # type: ignore[attr-defined]
            nonnegative=old.is_nonnegative,  # type: ignore[attr-defined]
        )

    base = sympy.sympify(expr)
    mapping = {old: as_symbol(old, new) for old, new in replacements.items()}
    # xreplace is much faster than subs, but only matches exact subtrees.
    return base.xreplace(mapping)
def is_symbolic(a: Any) -> TypeGuard[Union[torch.SymInt, torch.Tensor]]:
    """True for a SymInt, or for a tensor whose sizes/strides are symbolic."""
    if isinstance(a, torch.SymInt):
        return True
    return isinstance(a, torch.Tensor) and a._has_symbolic_sizes_strides
def any_is_symbolic(*args: Any) -> bool:
    """True if at least one argument satisfies ``is_symbolic``; stops at the first hit."""
    for candidate in args:
        if is_symbolic(candidate):
            return True
    return False
# Ops that are fundamentally incompatible with CUDA graph capture
# (e.g., CPU synchronization, dynamic memory allocation, etc.)
# Entries are op-name strings (mostly "<namespace>.<op>.<overload>"). Where the
# set is consulted is not visible here; presumably the cudagraph-safety check
# (is_cudagraph_unsafe_fx_node) — confirm.
FORBIDDEN_CUDAGRAPH_OPS = frozenset(
    [
        "aten._fused_moving_avg_obs_fq_helper.default",
        "aten._fused_moving_avg_obs_fq_helper_functional.default",
        "fbgemm.dense_to_jagged.default",
        "fbgemm.jagged_to_padded_dense.default",
        "run_and_save_rng_state",
        "run_with_rng_state",
        "aten._local_scalar_dense",
        # Technically, it's not necessary to ban this, because an
        # assert_scalar with constant arguments can be validly run
        # with CUDA graphs, but the operator is also pointless with
        # constant arguments, so might as well ban
        "aten._assert_scalar",
    ]
)
def get_first_incompatible_cudagraph_node(
    gm: torch.fx.GraphModule,
) -> Optional[torch.fx.Node]:
    """
    Return the first node in graph order that prevents CUDA graph capture.

    A node is incompatible if ``is_cudagraph_unsafe_fx_node`` flags it, or if
    its ``meta["val"]`` contains free unbacked symbols. Returns None when every
    node is compatible.
    """
    from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

    for candidate in gm.graph.nodes:
        if is_cudagraph_unsafe_fx_node(candidate):
            return candidate
        val = candidate.meta.get("val")
        if val is not None and free_unbacked_symbols(val):
            return candidate
    return None
def output_node(gm: torch.fx.GraphModule) -> Node:
    """Return the graph's last node, asserting it is the ``output`` node."""
    # reversed(graph.nodes) yields a node list, not an iterator, hence iter().
    backwards = iter(reversed(gm.graph.nodes))
    last = next(backwards)
    assert last.op == "output"
    return last
def get_all_devices(gm: torch.fx.GraphModule) -> OrderedSet[torch.device]:
    """
    Collect the devices of the graph's tensor inputs and tensor outputs.

    Only placeholders and output args whose ``meta["val"]`` is a
    ``torch.Tensor`` contribute. Input devices come first, then any new output
    devices.

    NOTE(review): only a tuple output is unpacked. A list output is treated as
    a single non-Node arg and contributes nothing.
    """
    input_devices: OrderedSet[torch.device] = OrderedSet()
    for placeholder in gm.graph.find_nodes(op="placeholder"):
        val = placeholder.meta.get("val")
        if isinstance(val, torch.Tensor):
            input_devices.add(val.device)

    out_arg = output_node(gm).args[0]  # type: ignore[union-attr]
    candidates = out_arg if isinstance(out_arg, tuple) else (out_arg,)
    out_devices: OrderedSet[torch.device] = OrderedSet()
    for arg in candidates:
        if not isinstance(arg, torch.fx.Node):
            continue
        val = arg.meta.get("val")
        if isinstance(val, torch.Tensor):
            out_devices.add(val.device)

    return input_devices | out_devices
- import gc
def unload_xpu_triton_pyds() -> None:
    """
    Best-effort unload of Triton's compiled extension modules on XPU.

    ``fresh_cache`` calls this on Windows with XPU available, right before it
    deletes the temporary cache directory. Presumably the loaded binaries must
    be released for that deletion to succeed.

    Steps:
    1. For every ``torch._inductor.runtime.compile_tasks.*`` module, call
       ``__del__`` on the launcher module of each compiled Triton kernel, then
       drop the module from ``sys.modules``.
    2. Delete Triton's active driver ``utils`` object and the ``instance``
       attribute on its class.
    3. Run the garbage collector.
    """
    # unload __triton_launcher.pyd
    for module_name in list(sys.modules.keys()):
        if not module_name.startswith("torch._inductor.runtime.compile_tasks."):
            continue
        m = sys.modules[module_name]
        for attr_name in m.__dict__:
            if attr_name.startswith("triton_"):
                kernel = getattr(m, attr_name)
                if isinstance(
                    kernel, torch._inductor.runtime.triton_heuristics.CachingAutotuner
                ):
                    for result in kernel.compile_results:
                        if isinstance(
                            result,
                            torch._inductor.runtime.triton_heuristics.TritonCompileResult,
                        ):
                            # pyrefly: ignore [missing-attribute]
                            result.kernel.run.mod.__del__()
        del sys.modules[module_name]

    # unload spirv_utils.pyd
    if "triton.runtime.driver" in sys.modules:
        mod = sys.modules["triton.runtime.driver"]
        # Remove both the class-level `instance` reference and the attribute
        # on the active driver so nothing keeps the utils object alive.
        del type(mod.driver.active.utils).instance
        del mod.driver.active.utils
    gc.collect()
# Caches whose cache_clear() is invoked by clear_caches() / fresh_cache().
_registered_caches: list[Any] = []


def clear_on_fresh_cache(obj: Any) -> Any:
    """
    Decorator that registers ``obj`` to be ``cache_clear()``-ed by ``fresh_cache()``.

    Returns ``obj`` unchanged.

    Raises:
        AttributeError: if ``obj`` has no callable ``cache_clear``.
    """
    if not callable(getattr(obj, "cache_clear", None)):
        raise AttributeError(f"{obj} does not have a cache_clear method")
    _registered_caches.append(obj)
    return obj
def clear_caches() -> None:
    """Call ``cache_clear()`` on every cache registered via ``clear_on_fresh_cache``."""
    for registered in _registered_caches:
        registered.cache_clear()
- @contextlib.contextmanager
- def _set_env(key: str, value: str) -> Iterator[None]:
- """Thread-safe env var set/restore using atomic C-level lookups.
- We avoid mock.patch.dict(os.environ, ...) because it internally calls
- os.environ.copy(), which iterates all env var keys then fetches values in
- separate steps. That approach is not atomic and can race with background threads
- (e.g. Triton async compilation) modifying the environment, causing KeyError,
- so we use os.environ.get() for individual keys which is an atomic C-level lookup.
- """
- old = os.environ.get(key)
- try:
- os.environ[key] = value
- yield
- finally:
- if old is None:
- os.environ.pop(key, None)
- else:
- os.environ[key] = old
@contextlib.contextmanager
def fresh_cache(
    cache_entries: Optional[dict[str, Any]] = None,
    dir: Optional[str] = None,
    delete: bool = True,
) -> Iterator[None]:
    """
    Contextmanager that provides a clean tmp cachedir for pt2 caches.

    While active, TORCHINDUCTOR_CACHE_DIR points at a new temporary directory
    (created under ``dir`` if given), and TRITON_CACHE_DIR points at its
    ``triton`` subdirectory. Caches registered with ``clear_on_fresh_cache``
    are cleared on entry and again on exit.

    Args:
        cache_entries: if a dict (it must be empty), it is filled after the body
            runs with ``{filename: size}`` for the files in the Triton cache
            dir, excluding names containing ".lock".
        dir: parent directory for the temporary cache dir.
        delete: remove the temporary dir after the body completes. If the
            body raises, the directory is kept and its path is logged.
    """
    clear_caches()

    from torch._inductor.cpp_builder import normalize_path_separator

    inductor_cache_dir = normalize_path_separator(tempfile.mkdtemp(dir=dir))
    try:
        with _set_env("TORCHINDUCTOR_CACHE_DIR", inductor_cache_dir):
            log.debug("Using inductor cache dir %s", inductor_cache_dir)
            triton_cache_dir = normalize_path_separator(
                os.path.join(inductor_cache_dir, "triton")
            )
            with _set_env("TRITON_CACHE_DIR", triton_cache_dir):
                yield
                # Reached only when the body finished without raising.
                if isinstance(cache_entries, dict):
                    assert len(cache_entries) == 0, "expected empty cache_entries dict"
                    if os.path.exists(triton_cache_dir):
                        files = os.listdir(triton_cache_dir)
                        cache_entries.update(
                            {
                                f: os.path.getsize(os.path.join(triton_cache_dir, f))
                                for f in files
                                if ".lock" not in f
                            }
                        )
        if delete:
            if is_windows() and torch.xpu.is_available():
                unload_xpu_triton_pyds()

            shutil.rmtree(
                inductor_cache_dir,
                # Let's not fail if we can't clean up the temp dir. Also note that for
                # Windows, we can't delete the loaded modules because the module binaries
                # are open.
                # (On non-Windows, errors go to `onerror`, which only logs a
                # warning, so cleanup failures do not raise there either.)
                ignore_errors=is_windows(),
                onerror=lambda func, path, exc_info: log.warning(
                    "Failed to remove temporary cache dir at %s",
                    inductor_cache_dir,
                    exc_info=exc_info,
                ),
            )
    except Exception:
        log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
        raise
    finally:
        clear_caches()
# Deprecated functions -- only keeping them for BC reasons.
# These are the old "inductor"-prefixed names; new code should use
# clear_on_fresh_cache / clear_caches / fresh_cache directly.
clear_on_fresh_inductor_cache = clear_on_fresh_cache
clear_inductor_caches = clear_caches
fresh_inductor_cache = fresh_cache
def argsort(seq: Sequence[Any], *, reverse: bool = False) -> list[int]:
    """
    Return the indices that sort ``seq`` (ascending, or descending if ``reverse``).

    Ties are broken by position so that, among equal values, the inner (later)
    dimension comes first when ascending and last when descending. For
    example, strides ``[32, 8, 8, 1]`` give ``[3, 2, 1, 0]``, not
    ``[3, 1, 2, 0]``.
    """
    # A stable descending sort keeps equal values in original order.
    # Reversing it yields the ascending order with the later index first.
    descending = sorted(range(len(seq)), key=seq.__getitem__, reverse=True)
    return descending if reverse else descending[::-1]
def argsort_sym(
    shape_env: ShapeEnv,
    seq: Sequence[Union[int, torch.SymInt, sympy.Expr]],
    *,
    reverse: bool = False,
) -> list[int]:
    """
    Symbolic-aware ``argsort``: the indices that sort ``seq`` (possibly symbolic).

    SymInts are unwrapped to their sympy expressions. Comparisons that are not
    plain bools are resolved with
    ``shape_env.evaluate_expr(..., size_oblivious=True)``. Equal values put the
    later index first, matching ``argsort``: ``[2048, 2048, 16, 1]`` gives
    ``[3, 2, 1, 0]``.
    """

    def truth(expr: Union[bool, torch.SymInt, sympy.Expr]) -> bool:
        if isinstance(expr, bool):
            return expr
        return shape_env.evaluate_expr(expr, size_oblivious=True)

    def compare(
        lhs: tuple[int, sympy.Expr], rhs: tuple[int, sympy.Expr]
    ) -> int:
        lhs_idx, lhs_val = lhs
        rhs_idx, rhs_val = rhs
        if truth(lhs_val < rhs_val):
            return -1
        if truth(lhs_val > rhs_val):
            return 1
        if lhs_idx == rhs_idx:
            return 0
        # Equal values: the larger original index sorts first.
        return 1 if lhs_idx < rhs_idx else -1

    indexed = [
        (position, item.node.expr if isinstance(item, torch.SymInt) else item)
        for position, item in enumerate(seq)
    ]
    indexed.sort(key=functools.cmp_to_key(compare), reverse=reverse)
    return [position for position, _ in indexed]
- @functools.lru_cache(8)
- def get_dtype_size(dtype: torch.dtype) -> int:
- # TODO: Investigate why uint64 tensor creation causes overflow error:
- # Workaround for RuntimeError in memory size calculation, but underlying cause unclear
- if dtype == torch.uint64:
- return 8
- return torch.empty((), dtype=dtype).element_size()
class LineContext(NamedTuple):
    """
    Marker entry stored among an ``IndentedBuffer``'s lines.

    It renders no text. ``IndentedBuffer.getvaluewithlinemap`` records
    ``context`` against the output line number at which the marker appears.
    """

    # Arbitrary payload associated with the following output line.
    context: Any
@dataclasses.dataclass
class ValueWithLineMap:
    """Rendered text of an ``IndentedBuffer`` plus its ``LineContext`` positions."""

    # The rendered buffer contents.
    value: str
    # (1-based output line number, LineContext.context) pairs, in buffer order.
    line_map: list[tuple[int, LineContext]]
class IndentedBuffer:
    """
    Line buffer for assembling generated source code with indentation.

    Each stored entry is one of:
    - a plain string, already prefixed with the indentation current at write
      time;
    - a ``DeferredLineBase``, rendered (or dropped, if it returns None) only
      when the buffer is read;
    - a ``LineContext`` marker, which produces no text.
    """

    # Spaces per indentation level.
    tabwidth = 4

    def __init__(self, initial_indent: int = 0) -> None:
        self._lines: list[Union[DeferredLineBase, LineContext, str]] = []
        # Current indentation level (in units of `tabwidth` spaces).
        self._indent = initial_indent

    @contextlib.contextmanager
    def set_tabwidth(self, tabwidth: int) -> Iterator[None]:
        """Temporarily use ``tabwidth`` spaces per level for lines written inside the block."""
        prev = self.tabwidth
        try:
            self.tabwidth = tabwidth
            yield
        finally:
            self.tabwidth = prev

    def getvaluewithlinemap(self) -> ValueWithLineMap:
        """
        Render the buffer to text and record where each ``LineContext`` falls.

        Deferred lines returning None are skipped. ``p`` tracks the 1-based
        number of the next output line; embedded newlines in a line advance
        it accordingly.
        """
        buf = StringIO()
        p = 1
        linemap: list[tuple[int, LineContext]] = []
        for li in self._lines:
            if isinstance(li, DeferredLineBase):
                line = li()
                if line is None:
                    continue
            elif isinstance(li, LineContext):
                linemap.append((p, li.context))
                continue
            else:
                line = li
            assert isinstance(line, str)
            buf.write(line)
            buf.write("\n")
            p += 1 + line.count("\n")
        return ValueWithLineMap(buf.getvalue(), linemap)

    def getvalue(self) -> str:
        """Return the rendered text (see ``getvaluewithlinemap``)."""
        return self.getvaluewithlinemap().value

    def getrawvalue(self) -> str:
        """
        Render like ``getvalue``, but join backslash-continued lines.

        A line ending in ``\\`` is written without the backslash and without
        the trailing newline, so the next line continues it.
        """
        buf = StringIO()
        for li in self._lines:
            if isinstance(li, DeferredLineBase):
                line = li()
                if line is None:
                    continue
            elif isinstance(li, LineContext):
                continue
            else:
                line = li
            assert isinstance(line, str)
            # backslash implies line continuation
            if line.endswith("\\"):
                buf.write(line[:-1])
            else:
                buf.write(line)
                buf.write("\n")
        return buf.getvalue()

    def clear(self) -> None:
        """Remove all stored lines (indentation level is left unchanged)."""
        self._lines.clear()

    def __bool__(self) -> bool:
        # True when any entry (including LineContext markers) has been stored.
        return bool(self._lines)

    def prefix(self) -> str:
        """Whitespace for the current indentation level."""
        return " " * (self._indent * self.tabwidth)

    def newline(self) -> None:
        """Append an empty line."""
        self.writeline("\n")

    def writeline(self, line: Union[LineContext, DeferredLineBase, str]) -> None:
        """
        Append one line at the current indentation.

        ``LineContext`` markers are stored as-is. Deferred lines get the prefix
        via ``with_prefix``. Whitespace-only strings become ``""`` (no trailing
        indentation).
        """
        if isinstance(line, LineContext):
            self._lines.append(line)
        elif isinstance(line, DeferredLineBase):
            self._lines.append(line.with_prefix(self.prefix()))
        elif line.strip():
            self._lines.append(f"{self.prefix()}{line}")
        else:
            self._lines.append("")

    def writelines(
        self, lines: Sequence[Union[LineContext, DeferredLineBase, str]]
    ) -> None:
        """``writeline`` each entry of ``lines`` in order."""
        for line in lines:
            self.writeline(line)

    def indent(self, offset: int = 1) -> contextlib.AbstractContextManager[None]:
        """Context manager that raises the indentation by ``offset`` levels for its body."""

        @contextlib.contextmanager
        def ctx() -> Iterator[None]:
            self._indent += offset
            try:
                yield
            finally:
                self._indent -= offset

        return ctx()

    def do_indent(self, offset: int = 1) -> None:
        """Raise the indentation by ``offset`` levels (not scoped; pair with do_unindent)."""
        self._indent += offset

    def do_unindent(self, offset: int = 1) -> None:
        """Lower the indentation by ``offset`` levels."""
        self._indent -= offset

    def splice(
        self, other_code: Union[IndentedBuffer, str], strip: bool = False
    ) -> None:
        """
        Append another buffer's lines, or a multi-line string, at the current indent.

        For an ``IndentedBuffer``, the smallest leading-whitespace width among
        its non-empty, non-``LineContext`` lines is removed from every line
        before re-indenting. ``strip`` is ignored in this case.

        For a ``str``, the text is ``textwrap.dedent``-ed, optionally
        left-stripped (``strip``), right-stripped, and written line by line.
        An empty result writes nothing.
        """
        if isinstance(other_code, IndentedBuffer):
            dedent = float("inf")
            for line in other_code._lines:
                if not isinstance(line, LineContext) and line:
                    dedent = min(dedent, len(line) - len(line.lstrip()))
            if math.isinf(dedent):
                dedent = 0
            for line in other_code._lines:
                if isinstance(line, LineContext):
                    self._lines.append(line)
                else:
                    # Call the base implementation explicitly, bypassing any
                    # subclass override of writeline.
                    IndentedBuffer.writeline(self, line[int(dedent) :])
        else:
            other_code = textwrap.dedent(other_code)
            if strip:
                other_code = other_code.lstrip()
            if not other_code:
                return
            other_code = other_code.rstrip()
            for s in other_code.split("\n"):
                self.writeline(s)

    def map(self, func: Callable[[Any], Any]) -> IndentedBuffer:
        """Return a new buffer (same indent) whose entries are ``func(entry)``."""
        res = IndentedBuffer(initial_indent=self._indent)
        res._lines = [func(line) for line in self._lines]
        return res

    def __repr__(self) -> str:
        return f"{type(self)}({self.getvalue()})"

    def __add__(self, other: Self) -> IndentedBuffer:
        """
        Concatenate two buffers that share the same indentation level.

        NOTE(review): entries are re-written through ``writelines``, which adds
        the current prefix again. At a nonzero indent, lines that already
        carry a prefix end up indented twice. Confirm whether callers rely on
        this.
        """
        assert self._indent == other._indent
        res = IndentedBuffer(initial_indent=self._indent)
        # TODO(rec): or should this be self.__class__(initial_indent=self._indent)?
        res.writelines(self._lines)
        res.writelines(other._lines)
        return res

    def contains(self, new_line: Union[DeferredLineBase, LineContext, str]) -> bool:
        """True if an equal entry is already stored (compares stored, prefixed lines)."""
        return new_line in self._lines
class FakeIndentedBuffer(IndentedBuffer):
    """
    Guard buffer whose every attribute access (except ``__class__``) raises.

    Used by TritonTemplateKernel so that writing to the body before
    ``TritonTemplateKernel.set_subgraph_body(name)`` is called fails loudly
    instead of silently going nowhere.
    """

    def __init__(self) -> None:
        super().__init__()

    def __getattribute__(self, name: str) -> Any:
        if name == "__class__":  # Allow access to the class attribute
            return object.__getattribute__(self, name)
        # The message fragments are implicitly concatenated, so each one ends
        # with a space; previously words ran together ("bufferis", "actualwrites").
        raise RuntimeError(
            f"Tried to call self.{name} on FakeIndentedBuffer. This buffer "
            "is currently used on TritonTemplateKernel to prevent actual "
            "writes to the body without explicitly specifying the body with "
            "`TritonTemplateKernel.set_subgraph_body(name)`"
        )
@contextlib.contextmanager
def restore_stdout_stderr() -> Iterator[None]:
    """Snapshot ``sys.stdout``/``sys.stderr`` and reinstate them on exit.

    Whatever the managed block does to the two streams, including when it
    raises, the originals are put back afterwards.
    """
    saved_out = sys.stdout
    saved_err = sys.stderr
    try:
        yield
    finally:
        sys.stdout = saved_out
        sys.stderr = saved_err
class DeferredLineBase:
    """A line that can be 'unwritten' at a later time.

    Subclasses decide at render time (``__call__``) whether the line survives,
    and implement ``_new_line`` so that string-like edits (prefixing,
    stripping, slicing) produce a new line with the same condition.
    """

    def __init__(self, line: str):
        # Whitespace-only input is normalized to the empty string.
        self.line = line if line.strip() else ""

    def __call__(self) -> Union[str, None]:
        """Return the line text, or None if the line has been 'unwritten'."""
        raise NotImplementedError

    def _new_line(self, line: str) -> Self:
        """Build a sibling deferred line carrying ``line`` and the same condition."""
        raise NotImplementedError

    def with_prefix(self, prefix: str) -> Self:
        """Sibling line with ``prefix`` placed in front of the text."""
        return self._new_line(f"{prefix}{self.line}")

    def lstrip(self) -> Self:
        """Sibling line with leading whitespace removed."""
        stripped = self.line.lstrip()
        return self._new_line(stripped)

    def __getitem__(self, index: Union[int, slice]) -> Self:
        """Sibling line made from ``self.line[index]``."""
        piece = self.line[index]
        return self._new_line(piece)

    def __bool__(self) -> bool:
        return len(self.line) > 0

    def __len__(self) -> int:
        return len(self.line)
class DelayReplaceLine(DeferredLineBase):
    """At end of codegen call `line.replace(key, value_fn())`.

    ``value_fn`` is invoked every time the line is rendered, so the
    replacement can be decided after the line was written.
    """

    def __init__(self, key: str, value_fn: Callable[[], str], line: str):
        super().__init__(line)
        self.key, self.value_fn = key, value_fn

    def __call__(self) -> str:
        replacement = self.value_fn()
        return self.line.replace(self.key, replacement)

    def _new_line(self, line: str) -> DelayReplaceLine:
        # Keep the same key/value_fn so derived lines substitute identically.
        return DelayReplaceLine(self.key, self.value_fn, line)
@functools.cache
def is_big_gpu(index_or_device: Union[int, torch.device] = 0) -> bool:
    """Decide whether the GPU is large enough for max_autotune_gemm mode.

    ``index_or_device`` is either a ``torch.device`` or an index combined with
    ``get_gpu_type()``. On ROCm, arch major versions below 9 or equal to 10
    are rejected. Elsewhere, the multiprocessor count must be at least 68
    (16 on XPU). A warning is logged whenever the result is False.
    """
    device = (
        index_or_device
        if isinstance(index_or_device, torch.device)
        else torch.device(get_gpu_type(), index_or_device)
    )
    prop = DeviceProperties.create(device)

    if torch.version.hip:
        # SM logic is not relevant to ROCm gpus; arbitrarily skip older models.
        assert prop.major is not None
        unsupported_arch = prop.major < 9 or prop.major == 10
        if unsupported_arch:
            log.warning("GPU arch does not support max_autotune_gemm mode usage")
        return not unsupported_arch

    min_sms = 16 if device.type == "xpu" else 68  # 3080
    avail_sms = prop.multi_processor_count
    if avail_sms >= min_sms:
        return True
    log.warning(
        "Not enough SMs to use max_autotune_gemm mode",
        extra={"min_sms": min_sms, "avail_sms": avail_sms},
    )
    return False
@functools.lru_cache
def get_max_num_sms() -> int:
    """Hardware compute-unit count, memoized.

    Returns XPU subslices when an XPU is available, otherwise the CUDA
    multiprocessor count of the "cuda" device.
    """
    if not torch.xpu.is_available():
        return torch.cuda.get_device_properties("cuda").multi_processor_count
    return torch.xpu.get_device_properties().gpu_subslice_count
@functools.lru_cache
def using_b200() -> bool:
    """Returns true if the device is a NVIDIA B200, otherwise returns false.

    The current CUDA device is classified by compute-capability major version
    10. The result is memoized for the process.
    """
    if torch.cuda.is_available():
        # compute capability 10.0 or 10.0a is NVIDIA B200
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        return props.major == 10
    return False
def get_num_sms() -> int:
    """Handle experimental carveout if set otherwise return hardware SM count.

    On XPU the carveout is not consulted and the hardware count is returned.
    """
    # TODO we need to properly guard on this global
    if torch.xpu.is_available():
        return get_max_num_sms()
    carveout = torch._C._get_sm_carveout_experimental()
    reserved = 0 if carveout is None else carveout
    return get_max_num_sms() - reserved
def get_tma_workspace_arg(
    num_tma_descriptors: int,
    device: torch.device,
    num_programs: Optional[int] = None,
) -> WorkspaceArg:
    """Builds and returns a WorkspaceArg for the device side TMA workspace buffer.

    The workspace holds ``num_tma_descriptors`` descriptors of
    ``TMA_DESCRIPTOR_SIZE`` for each program. ``num_programs`` defaults to
    ``get_num_sms()``. The zero mode is ``WorkspaceZeroMode.from_bool(False)``,
    which presumably means no zero-fill.
    """
    from .codegen.common import WorkspaceArg, WorkspaceZeroMode

    programs = get_num_sms() if num_programs is None else num_programs
    return WorkspaceArg(
        count=programs * num_tma_descriptors * TMA_DESCRIPTOR_SIZE,
        zero_mode=WorkspaceZeroMode.from_bool(False),
        device=device,
        outer_name=WorkspaceArg.unique_name(),
    )
def _use_template_for_gpu(
    layout: Layout, allowed_layout_dtypes: list[torch.dtype]
) -> bool:
    """True when ``layout`` lives on a GPU, has an allowed dtype, and ``is_big_gpu``.

    A debug message is logged when the dtype alone disqualifies the layout.
    """
    dtype_allowed = layout.dtype in allowed_layout_dtypes
    if not dtype_allowed:
        log.debug(
            "Not using template since dtype %s is not in allowed layout dtypes %s",
            layout.dtype,
            allowed_layout_dtypes,
        )
    return (
        is_gpu(layout.device.type)
        and dtype_allowed
        and is_big_gpu(layout.device)
    )
def _backend_in_csv(backend: str, backends_csv: str) -> bool:
    """Case-insensitive test that ``backend`` is an entry of a comma-separated list."""
    wanted = backend.upper()
    return any(
        entry.strip() == wanted for entry in backends_csv.upper().split(",")
    )


def _use_autotune_backend(backend: str) -> bool:
    """True if ``backend`` is listed in ``config.max_autotune_gemm_backends``."""
    return _backend_in_csv(backend, config.max_autotune_gemm_backends)


def _use_conv_autotune_backend(backend: str) -> bool:
    """True if ``backend`` is listed in ``config.max_autotune_conv_backends``."""
    return _backend_in_csv(backend, config.max_autotune_conv_backends)
def use_triton_template(
    layout: Layout,
    *,
    enable_int32: bool = False,
    enable_float8: bool = False,
    check_max_autotune: bool = True,
) -> bool:
    """Return True if the Triton GEMM template may be used for ``layout``.

    The dtype must be fp16/bf16/fp32, plus int32 when ``enable_int32`` is set
    and the two fp8 formats when ``enable_float8`` is set. The device must be
    a suitable GPU or a CPU. Max-autotune must be on unless
    ``check_max_autotune`` is False. "TRITON" must be a configured backend and
    the device must support Triton templates.
    """
    from .codegen.common import BackendFeature, has_backend_feature

    layout_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    if enable_int32:
        layout_dtypes.append(torch.int32)
    if enable_float8:
        layout_dtypes += [torch.float8_e4m3fn, torch.float8_e5m2]

    device_type = layout.device.type
    on_supported_device = (
        is_gpu(device_type) and _use_template_for_gpu(layout, layout_dtypes)
    ) or (device_type == "cpu" and layout.dtype in layout_dtypes)
    if not on_supported_device:
        return False

    # some callers handle max-autotune checking externally
    if check_max_autotune and not (config.max_autotune or config.max_autotune_gemm):
        return False

    return _use_autotune_backend("TRITON") and has_backend_feature(
        layout.device, BackendFeature.TRITON_TEMPLATES
    )
def can_use_tma(
    *matrices: IRNode, output_layout: Optional[Layout] = None, add_guards: bool = False
) -> bool:
    """
    Return True iff *all* supplied tensors satisfy the CUDA TMA constraints
    that Triton relies on today.
    * https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
    A tensor is accepted when:
    * 1 ≤ rank ≤ 5 (cuTensorMapEncodeTiled)
    * dtype in _TMA_SUPPORTED_DTYPES (CUtensorMapDataType enum)
    * Base pointer 16-byte aligned
    * Exactly one contiguous ("inner") dim with stride 1
    * All "outer" dims have 16-byte aligned strides
    * Inner dim size × itemsize is a multiple of 16
    * For 1-byte dtypes (e.g. FP8), inner dim ≥ 32

    Inputs whose device type is "xpu" use a separate check: the last stride
    must be 1 and no size may be statically known to exceed 2**32 - 1.

    Args:
        matrices: input IR nodes; every one must pass.
        output_layout: optional output layout, checked the same way plus an
            alignment check on ``layout.offset``. ``None`` skips the output
            check.
        add_guards: if True, sizes/strides are concretized with
            ``guard_int_seq``; otherwise ``symbolic_hint`` values are used.

    Also returns False when ``has_triton_tma_device()`` is False.
    """
    from torch.utils._triton import has_triton_tma_device

    from .virtualized import V

    def _aligned(expr_bytes: Union[int, sympy.Expr]) -> bool:
        # True only when the quantity is statically known to be a multiple of
        # TMA_ALIGNMENT (the docstring describes this as 16 bytes).
        return V.graph.sizevars.statically_known_multiple_of(expr_bytes, TMA_ALIGNMENT)

    def _is_tma_compatible_layout(layout: Optional[Layout]) -> bool:
        # No output layout supplied: nothing to check.
        if layout is None:
            return True
        sizes = layout.size
        strides = layout.stride
        dtype = layout.dtype

        # Verify the output is 16-byte aligned
        # NOTE(review): layout.offset is passed without multiplying by
        # itemsize. If offset is in elements, this is stricter than requiring
        # offset * itemsize to be 16-byte aligned, though never looser.
        # Confirm this is intended.
        if not _aligned(layout.offset):
            return False

        return _is_tma_compatible(sizes, strides, dtype)

    def _is_tma_compatible_matrix(m: IRNode) -> bool:
        sizes = m.get_size()
        strides = m.get_stride()
        dtype = m.get_dtype()

        # Base pointer 16-byte aligned
        if m.get_name() in V.graph.unaligned_buffers:
            return False

        # XPU inputs are routed to the XPU-specific check below.
        if (m_device := m.get_device()) is not None and m_device.type == "xpu":
            return _is_tma_compatible_xpu(sizes, strides, dtype)

        return _is_tma_compatible(sizes, strides, dtype)

    def _is_tma_compatible(
        sizes: Sequence[sympy.Expr],
        strides: Sequence[_IntLike],
        dtype: torch.dtype,
    ) -> bool:
        rank = len(sizes)
        itemsize = dtype.itemsize

        if rank < 1 or rank > 5:
            return False
        if dtype not in _TMA_SUPPORTED_DTYPES:
            return False

        # Concretize shapes: with add_guards the values are guarded ints,
        # otherwise symbolic hints are used.
        if add_guards:
            sizes_i = V.graph.sizevars.guard_int_seq(sizes)
            strides_i = V.graph.sizevars.guard_int_seq(strides)
        else:
            sizes_i = [V.graph.sizevars.symbolic_hint(s) for s in sizes]
            strides_i = [V.graph.sizevars.symbolic_hint(st) for st in strides]

        # Find the single contiguous ("inner") dim
        inner = [
            i
            for i, st in enumerate(strides_i)
            if V.graph.sizevars.statically_known_equals(st, 1)
        ]
        if len(inner) != 1:
            return False
        inner_idx = inner[0]

        # All "outer" dims must have 16-byte aligned strides
        for i, st in enumerate(strides_i):
            if i == inner_idx:
                continue
            if not _aligned(st * itemsize):
                return False

        # Inner dim byte width must be a multiple of 16 B
        inner_dim = sizes_i[inner_idx]
        if not _aligned(inner_dim * itemsize):
            return False

        # 1-byte dtypes (FP8 etc.) need inner dim ≥ 32 for tensor core alignment
        if itemsize == 1 and not V.graph.sizevars.statically_known_geq(inner_dim, 32):
            return False

        return True

    def _is_tma_compatible_xpu(
        sizes: Sequence[sympy.Expr],
        strides: Sequence[_IntLike],
        dtype: torch.dtype,
    ) -> bool:
        # Unlike _is_tma_compatible, this path does not check rank or dtype,
        # ignores add_guards, and does not use `dtype` at all.
        # Make sure the last dimension is contiguous
        last_stride = strides[-1]
        last_stride_hint = V.graph.sizevars.symbolic_hint(last_stride)
        if not V.graph.sizevars.statically_known_equals(last_stride_hint, 1):
            return False

        # Triton's type of index is uint32, so all dimensions must fit in uint32
        MAX_UINT32 = 2**32 - 1
        for size in sizes:
            size_hint = V.graph.sizevars.symbolic_hint(size)
            if V.graph.sizevars.statically_known_gt(size_hint, MAX_UINT32):
                return False
        return True

    # Device capability first, then every input, then the optional output.
    return (
        has_triton_tma_device()
        and all(_is_tma_compatible_matrix(m) for m in matrices)
        and _is_tma_compatible_layout(output_layout)
    )
def use_triton_tma_template(
    *matrices: IRNode, output_layout: Layout, add_guards: bool = False
) -> bool:
    """Return True if the persistent TMA matmul template can be used.

    Requires ``config.triton.enable_persistent_tma_matmul``, every input to be
    2-D, and ``can_use_tma`` to accept the inputs. The output layout is also
    checked when ``config.triton.enable_template_tma_store`` is set.

    The cheap checks run before ``can_use_tma``. When the template is
    disabled, no shape/stride guarding (``guard_int_seq`` under
    ``add_guards=True``) is performed for a result that would be discarded.
    """
    if not config.triton.enable_persistent_tma_matmul:
        return False
    if not all(len(m.get_size()) == 2 for m in matrices):
        return False
    layout = output_layout if config.triton.enable_template_tma_store else None
    return can_use_tma(*matrices, output_layout=layout, add_guards=add_guards)
def use_triton_blackwell_tma_template(
    *matrices: IRNode, output_layout: Layout, add_guards: bool = False
) -> bool:
    """Return True if the Blackwell TMA template can be used.

    On top of ``use_triton_tma_template``, this requires Triton's host-side
    tensor-descriptor TMA API and a datacenter Blackwell GPU.
    """
    tma_ok = use_triton_tma_template(
        *matrices, output_layout=output_layout, add_guards=add_guards
    )
    if tma_ok:
        from torch.utils._triton import has_triton_tensor_descriptor_host_tma

        from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch

        # Blackwell template require the tensor descriptor API, not the experimental API.
        return (
            has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch()
        )
    return False
def use_triton_scaling_template(
    scale_option_a: ScalingType,
    scale_option_b: ScalingType,
    scaling_types: list[ScalingType],
) -> bool:
    """True when both operands' scaling types appear in ``scaling_types``."""
    return all(option in scaling_types for option in (scale_option_a, scale_option_b))
def _module_spec_available(name: str) -> bool:
    """Return True if top-level module ``name`` can be located without importing it.

    ``importlib.util.find_spec`` raises ``ValueError`` (not ``ImportError``)
    when ``name`` is already in ``sys.modules`` but its ``__spec__`` is unset
    or None, for example a module injected at runtime. Such a module is
    importable, so that case reports True instead of crashing the probe.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False
    except ValueError:
        return name in sys.modules


@functools.lru_cache(maxsize=1)
def ensure_cute_available() -> bool:
    """Check if CuTeDSL is importable; cache the result for reuse.
    Call ensure_cute_available.cache_clear() after installing CuTeDSL
    in the same interpreter to retry the import.
    """
    return _module_spec_available("cutlass")


@functools.lru_cache(maxsize=1)
def ensure_nv_universal_gemm_available() -> bool:
    """Check if NVIDIA Universal GEMM (cutlass_api) is importable; cache the result for reuse.
    Call ensure_nv_universal_gemm_available.cache_clear() after installing cutlass_api
    in the same interpreter to retry the import.
    """
    return _module_spec_available("cutlass_api")


@functools.lru_cache(maxsize=1)
def ensure_nvmatmul_heuristics_available() -> bool:
    """Check if nvMatmulHeuristics is importable; cache the result for reuse.
    nvMatmulHeuristics provides performance model-based kernel selection
    for NVIDIA GEMM operations.
    Call ensure_nvmatmul_heuristics_available.cache_clear() after installing
    nvMatmulHeuristics in the same interpreter to retry the import.
    """
    return _module_spec_available("nvMatmulHeuristics")
def use_blackwell_cutedsl_grouped_mm(
    mat_a: Any,
    mat_b: Any,
    layout: Layout,
    a_is_2d: bool,
    b_is_2d: bool,
    offs: Optional[Any],
    bias: Optional[Any],
    scale_result: Optional[Any],
) -> bool:
    """
    Returns True if we can use the blackwell kernel for grouped mm.
    Required conditions:
    1. CuTeDSL backend is enabled
    2. CuTeDSL is available
    3. We are on a blackwell arch
    4. The dtype is bf16
    5. Max autotune or max autotune gemm is enabled
    6. A, B, and the output are 16B aligned
    7. We are not using dynamic shapes
    8. A is 2d
    9. B is 3d
    10. Offsets are provided
    11. Bias and Scale are not provided
    """
    if not (ensure_cute_available() and _use_autotune_backend("CUTEDSL")):
        return False

    from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch

    if not (is_gpu(layout.device.type) and is_datacenter_blackwell_arch()):
        return False
    if not _use_template_for_gpu(layout, [torch.bfloat16]):
        return False
    if not (config.max_autotune or config.max_autotune_gemm):
        return False
    # Checks for 16B ptr and stride alignment
    if not can_use_tma(mat_a, mat_b, output_layout=layout):
        return False
    if is_dynamic(mat_a) or is_dynamic(mat_b):
        return False
    # A must be 2-D and B must not be (i.e. the 3-D operand); offsets are
    # required; bias and result scaling are unsupported.
    return bool(
        a_is_2d
        and not b_is_2d
        and offs is not None
        and bias is None
        and scale_result is None
    )
def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
    """Return whether the CUTLASS GEMM template may be used.

    This requires a positive size hint for ``m*n*k`` of at least
    ``config.cutlass.cutlass_backend_min_gemm_size``. ROCm is excluded. The
    output dtype must be fp16/bf16/int32 on a suitable GPU, max-autotune must
    be on, "CUTLASS" must be a configured backend, and the CUTLASS library
    must be importable (a warning is logged otherwise).
    """
    from .virtualized import V

    gemm_size = V.graph.sizevars.optimization_hint(m * n * k, fallback=-1)
    if gemm_size <= 0 or gemm_size < config.cutlass.cutlass_backend_min_gemm_size:
        return False

    from .codegen.cutlass.utils import try_import_cutlass

    # Do not use cutlass template on ROCm
    if torch.version.hip:
        return False

    # output dtype
    # FP32 not supported: https://github.com/pytorch/pytorch/issues/145952
    layout_dtypes = [torch.float16, torch.bfloat16, torch.int32]
    eligible = (
        _use_template_for_gpu(layout, layout_dtypes)
        and (config.max_autotune or config.max_autotune_gemm)
        and _use_autotune_backend("CUTLASS")
    )
    if not eligible or try_import_cutlass():
        return eligible
    log.warning(
        "Failed to import CUTLASS lib. Please check whether "
        "_inductor.config.cutlass.cutlass_dir %s is set correctly. "
        "Skipping CUTLASS backend for now.",
        config.cutlass.cutlass_dir,
    )
    return False
def use_nv_universal_gemm_template(
    layout: Layout,
    m: _IntLike,
    n: _IntLike,
    k: _IntLike,
    mat_a: IRNode,
    mat_b: IRNode,
    offs: Optional[IRNode] = None,
    g: Optional[_IntLike] = None,
) -> bool:
    """
    Return True if we can use the NVIDIA Universal GEMM Template.
    Required conditions:
    1. NVGEMM backend is enabled
    2. cutlass_api is available
    3. We are on a NVIDIA GPU
    4. Max autotune or max autotune gemm is enabled
    5. Not in AOT Inductor mode (requires runtime JIT compilation)
    6. Base pointers are 16-byte aligned
    7. Shape dimensions are not unbacked symbols
    Note:
    - Shape and stride constraints are handled internally by
      cutlass_api.get_kernels() which filters incompatible kernels.
    - GroupedGemm currently only supports TN layout (column-major B).
      Any other layout will act as a noop and fall back to ATen.
    - Dynamic shapes are supported as long as they have hints
      (from example inputs).
    """
    from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols

    if not (
        ensure_cute_available()
        and ensure_nv_universal_gemm_available()
        and _use_autotune_backend("NVGEMM")
    ):
        return False

    from .virtualized import V

    # Runtime JIT compilation is required, so AOT compilation is excluded.
    if V.aot_compilation:
        return False
    if layout.device.type != "cuda" or torch.version.hip:
        return False
    if not (config.max_autotune or config.max_autotune_gemm):
        return False

    # cutlass_api can't handle unbacked symbols because it needs to evaluate
    # shape constraints (e.g., stride divisibility by 8, N/K divisibility by 16).
    # Unbacked symbols have no hint values, causing GuardOnDataDependentSymNode errors.
    dims = [m, n, k] if g is None else [m, n, k, g]
    if any(has_free_unbacked_symbols(dim) for dim in dims):
        return False

    # Base pointer must be 16-byte aligned. cutlass_api can't check this at
    # compile time because it only sees FakeTensors without real data pointers.
    tensors = [mat_a, mat_b] if offs is None else [mat_a, mat_b, offs]
    return not any(t.get_name() in V.graph.unaligned_buffers for t in tensors)
def _use_cutlass_for_op(op_name: str) -> bool:
    """Check if CUTLASS should be used for the given operation.

    ``config.cutlass.cutlass_enabled_ops`` is either "all" (case-insensitive)
    or a comma-separated list of op names, which is matched case-insensitively.
    """
    enabled_ops = config.cutlass.cutlass_enabled_ops.upper()
    if enabled_ops != "ALL":
        wanted = op_name.upper()
        return any(op.strip() == wanted for op in enabled_ops.split(","))
    return True
# Integer-valued size/stride quantity: a concrete Python int or a sympy
# expression (which may be symbolic, e.g. a dynamic-shape size).
_IntLike: TypeAlias = Union[int, sympy.Expr]
@functools.cache
def use_decompose_k_choice(
    m: _IntLike, n: _IntLike, k: _IntLike, threshold_multiple: int = 1
) -> bool:
    """Whether the decompose-K choice should be offered for an (m, n, k) GEMM.

    Requires K to be statically known to be at least
    ``config.triton.decompose_k_threshold * threshold_multiple`` times both M
    and N. AOT mode and the cpp wrapper must be off, and
    ``config.triton.num_decompose_k_splits`` must be positive.

    NOTE(review): memoized on (m, n, k, threshold_multiple) only, although the
    answer also reads config and ``V.graph`` state. Confirm the cache is
    cleared between compilations where those differ.
    """
    from torch._inductor.virtualized import V

    threshold = config.triton.decompose_k_threshold * threshold_multiple
    k_dominates = sympy.And(
        sympy.Ge(k, threshold * m),
        sympy.Ge(k, threshold * n),
    )
    if not V.graph.sizevars.statically_known_true(k_dominates):
        return False
    # TODO: Support AOTI for decomposeK
    if V.graph.aot_mode or V.graph.cpp_wrapper:
        return False
    return config.triton.num_decompose_k_splits > 0
@functools.cache
def use_contiguous(m: _IntLike, n: _IntLike, k: _IntLike) -> bool:
    """
    Check if we should use the contiguous subgraph transform.
    This transform makes the second matrix contiguous before the matmul.

    Only applies on ROCm (``torch.version.hip``), when K is statically known
    to be at least ``config.rocm.contiguous_threshold`` times both M and N,
    and when neither AOT mode nor the cpp wrapper is active.

    NOTE(review): memoized on (m, n, k) although the answer also reads config
    and ``V.graph``. Confirm the cache is reset between compilations.
    """
    contiguous_threshold = config.rocm.contiguous_threshold
    from torch._inductor.virtualized import V

    if not torch.version.hip:  # Only relevant on AMD
        return False
    # Similar conditions to decompose_k but for contiguous transform
    k_is_large = sympy.And(
        sympy.Ge(k, contiguous_threshold * m),
        sympy.Ge(k, contiguous_threshold * n),
    )
    if not V.graph.sizevars.statically_known_true(k_is_large):
        return False
    return not (V.graph.aot_mode or V.graph.cpp_wrapper)
@functools.cache
def get_k_splits(m: _IntLike, n: _IntLike, k: _IntLike) -> list[int]:
    """Candidate split counts for decomposing the K dimension of an (m, n, k) GEMM.

    Returns divisors ``d`` of ``k`` with ``2 <= d <= min(k // m, k // n)``
    (the upper bound is 256 when ``m`` or ``n`` is symbolic) whose per-split
    size ``k // d`` is at least 128. They are ordered as: power-of-two split
    sizes first, then multiples of 32, then the rest, each group in ascending
    ``d``. Unless ``config.max_autotune_gemm_search_space`` is "EXHAUSTIVE",
    the list is truncated to ``config.triton.num_decompose_k_splits`` entries.

    A symbolic ``k`` returns a fixed hand-tuned list. A zero ``m`` or ``n``
    returns ``[]`` instead of raising ``ZeroDivisionError``.

    NOTE(review): memoized on (m, n, k) although the result also depends on
    config values. Confirm the cache is cleared when those change.
    """
    # To limit compile time
    k_splits_limit = config.triton.num_decompose_k_splits

    # Hand-tuned
    default_k_splits = [16, 32, 64, 128, 256]
    # A symbolic K has no concrete divisors, so fall back to the hand-tuned list.
    if isinstance(k, sympy.Expr) and not k.is_number:
        return default_k_splits
    if k_splits_limit == 0:
        return []

    # Normalize numeric sympy values to int so // and & behave like ints.
    k = int(k)
    m_symbolic = isinstance(m, sympy.Expr) and not m.is_number
    n_symbolic = isinstance(n, sympy.Expr) and not n.is_number
    if m_symbolic or n_symbolic:
        max_k_split = 256
    else:
        m_val, n_val = int(m), int(n)
        if m_val == 0 or n_val == 0:
            # Degenerate GEMM: nothing to split, and k // 0 would raise.
            return []
        max_k_split = min(k // m_val, k // n_val)
    min_k_split = 2

    pow_of_2_divisors: list[int] = []
    mul_of_32_divisors: list[int] = []
    rest_of_splits: list[int] = []
    # k has to be divisible by the number of splits
    for d in sympy.divisors(k):
        if not min_k_split <= d <= max_k_split:
            continue
        k_part = k // d
        # Smaller than 128 might not even fit in a single tile, BLOCK_K can be 128
        if k_part < 128:
            continue
        if k_part & (k_part - 1) == 0:
            # Power of 2 split sizes are best performing, conform to hardware
            pow_of_2_divisors.append(d)
        elif k_part % 32 == 0:
            mul_of_32_divisors.append(d)
        else:
            rest_of_splits.append(d)

    best_splits = pow_of_2_divisors + mul_of_32_divisors + rest_of_splits
    if config.max_autotune_gemm_search_space == "EXHAUSTIVE":
        return best_splits
    # Otherwise, conform results to k_splits_limit
    return best_splits[:k_splits_limit]
@functools.cache
def _rocm_native_device_arch_name(device: str) -> str:
    """``gcnArchName`` reported by torch for ``device`` (memoized per device)."""
    props = torch.cuda.get_device_properties(device)
    return props.gcnArchName
@functools.cache
def try_import_ck_lib() -> tuple[
    Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
]:
    """Locate the optional ``ck4inductor`` package (memoized).

    Returns ``(package_dir, gen_ops_library, gen_ops_preselected,
    CKGemmOperation)``. If the imports fail with ImportError,
    ``package_dir`` is None, both generators return empty lists, and
    ``CKGemmOperation`` is an empty placeholder class.
    """
    try:
        import ck4inductor  # type: ignore[import]
        from ck4inductor.universal_gemm.gen_instances import (  # type: ignore[import]
            gen_ops_library,
            gen_ops_preselected,
        )
        from ck4inductor.universal_gemm.op import (  # type: ignore[import]
            CKGemmOperation,
        )

        package_dirname = os.path.dirname(ck4inductor.__file__)
    except ImportError:

        def gen_ops_library() -> list[Any]:
            return []

        def gen_ops_preselected() -> list[Any]:
            return []

        class CKGemmOperation:  # type: ignore[no-redef]
            pass

        return None, gen_ops_library, gen_ops_preselected, CKGemmOperation

    return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation
def use_ck_template(layout: Layout) -> bool:
    """Gate for Composable Kernel (CK) templates.

    Requires max-autotune, a ROCm build, a "cuda"-type device, a requested
    arch (``config.rocm.arch`` or the device's native arch) that is in
    ``config.rocm.ck_supported_arch``, an fp16/bf16/fp32 dtype, and an
    importable ``ck4inductor``.

    Side effect: on success, ``config.rocm.ck_dir`` is set to the CK package
    directory.
    """
    # config knobs, platform (ROCm only) and device checks
    if not (config.max_autotune or config.max_autotune_gemm):
        return False
    if not torch.version.hip or layout.device.type != "cuda":
        return False

    # hardware check
    # if config arch list is not specified, get the native arch from the device properties
    native_arch = _rocm_native_device_arch_name(layout.device)
    requested_archs = {
        full_name.split(":")[0]: full_name for full_name in config.rocm.arch
    } or {native_arch.split(":")[0]: native_arch}
    if not requested_archs.keys() & config.rocm.ck_supported_arch:
        return False

    # supported input dtypes
    if layout.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return False

    ck_package_dirname = try_import_ck_lib()[0]
    if not ck_package_dirname:
        log.warning("Please pip install Composable Kernel package")
        return False

    config.rocm.ck_dir = ck_package_dirname
    return True
def _use_ck_gemm_backend(backend: str, layout: Layout, m: int, n: int, k: int) -> bool:
    """Shared CK GEMM gate.

    The backend must be enabled, CK must be usable for ``layout``, and the
    ``m*n*k`` size hint must be positive.
    """
    from .virtualized import V

    if not _use_autotune_backend(backend) or not use_ck_template(layout):
        return False
    return V.graph.sizevars.optimization_hint(m * n * k, fallback=-1) > 0


def use_ck_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool:
    """CK universal-GEMM template gate (backend name "CK")."""
    return _use_ck_gemm_backend("CK", layout, m, n, k)


def use_ck_tile_gemm_template(layout: Layout, m: int, n: int, k: int) -> bool:
    """CK-tile GEMM template gate (backend name "CKTILE")."""
    return _use_ck_gemm_backend("CKTILE", layout, m, n, k)
def use_ck_conv_template(layout: Layout) -> bool:
    """CK convolution gate: "CK" must be a conv backend and CK usable for ``layout``."""
    if not _use_conv_autotune_backend("CK"):
        return False
    return use_ck_template(layout)
def _use_template_for_cpu(layout: Layout) -> bool:
    """True for CPU layouts when max-autotune (or max-autotune-gemm) is enabled."""
    autotune_on = config.max_autotune or config.max_autotune_gemm
    return autotune_on and layout.device.type == "cpu"
def use_cpp_bmm_template(
    layout: Layout, mat1: Union[ReinterpretView, Buffer], mat2: IRNode
) -> bool:
    """Whether the C++ batched-GEMM template can handle ``mat1 @ mat2``.

    Requires ``use_cpp_gemm_template`` (without a constant ``mat2``) and a
    ``mat1`` that is contiguous overall or, for fp32 3-D input under the CPU
    template gate, contiguous within each batch's 2-D matrix.
    """
    from .ir import Layout

    assert isinstance(mat1.layout, Layout)

    # In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous.
    # But the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune.
    # So here we specifically check for contiguity within the 2D matrix of each batch.
    size, stride = mat1.layout.size, mat1.layout.stride
    per_batch_contiguous = (
        _use_template_for_cpu(layout)
        and mat1.get_dtype() == torch.float32
        and len(size) == 3
        and len(stride) == 3
        and stride[1] == size[2]
        and stride[2] == 1
    )
    gemm_ok = use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False)
    if not gemm_ok:
        return gemm_ok
    return mat1.layout.is_contiguous() or per_batch_contiguous
def use_cpp_gemm_template(
    layout: Layout,
    mat1: IRNode,
    mat2: IRNode,
    mat2_transposed: bool = False,
    require_constant_mat2: bool = True,
    is_woq_int4: bool = False,
    q_group_size: Optional[int] = None,
) -> bool:
    """Decide whether the C++ GEMM template can be used for ``mat1 @ mat2`` on CPU.

    Requires all of the following:
    * the CPU template gate is on and "CPP" is a configured gemm backend;
    * ``config.cpp.weight_prepack`` is enabled;
    * N and K (from ``mm_args``) have no free symbols;
    * ``create_micro_gemm`` returns a micro-kernel;
    * the output layout dtype is fp32/bf16/half/uint8/int8;
    * ``mat1``'s last stride is 1;
    * ``mat2``, after unwrapping any view, is a ``StorageBox`` that is a
      module buffer unless ``require_constant_mat2`` is False.

    Side effect: ``mat1.freeze_layout()`` is called when the checks reach
    that point.
    NOTE(review): ``layout``/``mat1``/``mat2`` are rebound to the values
    returned by ``mm_args``, so the final dtype check uses that returned layout.
    """
    from . import ir
    from .codegen.cpp_micro_gemm import create_micro_gemm
    from .codegen.cpp_utils import get_gemm_template_output_and_compute_dtype
    from .kernel.mm_common import mm_args

    if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"):
        return False

    if not config.cpp.weight_prepack:
        return False

    int8_gemm = mat1.get_dtype() in [torch.uint8, torch.int8]
    layout_dtypes = [torch.float32, torch.bfloat16, torch.half, torch.uint8, torch.int8]
    # For int8/uint8 inputs the caller's layout dtype is forwarded as out_dtype.
    m, n, k, layout, mat1, mat2 = mm_args(
        mat1,
        mat2,
        out_dtype=layout.dtype if int8_gemm else None,
        mat2_transposed=mat2_transposed,
        use_4x2_dim=is_woq_int4,
    )

    # TODO(jgong5): support dynamic shapes for n or k
    if has_free_symbols((n, k)):
        return False

    # Look through views so the StorageBox check below sees the underlying node.
    if isinstance(mat2, ir.BaseView):
        mat2 = mat2.unwrap_view()

    output_dtype, _ = get_gemm_template_output_and_compute_dtype(mat1.get_dtype())
    micro_gemm = create_micro_gemm(
        "micro_gemm",
        m,
        n,
        k,
        input_dtype=mat1.get_dtype(),
        input2_dtype=mat2.get_dtype(),
        output_dtype=output_dtype,
        num_threads=parallel_num_threads(),
        use_ref=not is_woq_int4,
        q_group_size=q_group_size,
    )

    def is_last_dim_stride1(x: IRNode) -> bool:
        # Freezes x's layout (a side effect) before reading its strides.
        x.freeze_layout()
        return x.get_stride()[-1] == 1

    return (
        layout.dtype in layout_dtypes
        and micro_gemm is not None
        and is_last_dim_stride1(mat1)  # TODO(jgong5): support transposed input
        and isinstance(mat2, ir.StorageBox)
        and (mat2.is_module_buffer() or not require_constant_mat2)
    )
def use_aten_gemm_kernels() -> bool:
    """ATen GEMM kernels are used unless max-autotune is on and "ATEN" is not a configured backend."""
    if config.max_autotune or config.max_autotune_gemm:
        return _use_autotune_backend("ATEN")
    return True
class DebugDirManager:
    """Temporarily point ``torch._dynamo.config.debug_dir_root`` at a fresh path.

    On enter, the root becomes ``<previous root>_tmp_<id>``, where ``id`` comes
    from a class-level counter. On exit, that directory is removed if it
    exists, and the previous root is always restored.
    """

    counter = itertools.count(0)
    prev_debug_name: str

    def __init__(self) -> None:
        self.id = next(DebugDirManager.counter)

    def __enter__(self) -> None:
        self.prev_debug_name = torch._dynamo.config.debug_dir_root
        self.new_name = f"{self.prev_debug_name}_tmp_{self.id}"
        torch._dynamo.config.debug_dir_root = self.new_name

    def __exit__(self, *args: Any) -> None:
        try:
            # Nothing in this class creates the directory, so it may not
            # exist; a missing directory is not an error.
            shutil.rmtree(self.new_name)
        except FileNotFoundError:
            pass
        finally:
            # Restore the caller's setting even if cleanup fails.
            torch._dynamo.config.debug_dir_root = self.prev_debug_name
def run_and_get_code(
    fn: Callable[P, _T],
    *args: P.args,
    **kwargs: P.kwargs,
) -> tuple[_T, list[str]]:
    """Run ``fn`` and capture the output code Inductor saves while it runs.

    Dynamo is reset first. ``GraphLowering.save_output_code`` is patched to
    record each code string in an ``OrderedSet``, so duplicates are dropped.
    Returns ``(fn's result, recorded codes as a list)``.
    """
    from .graph import GraphLowering

    recorded: OrderedSet[str] = OrderedSet()

    def _record(code: str) -> None:
        recorded.add(code)

    with mock.patch.object(GraphLowering, "save_output_code", _record):
        torch._dynamo.reset()
        outcome = fn(*args, **kwargs)
    return outcome, list(recorded)
def run_and_get_kernels(
    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
) -> tuple[_T, list[str]]:
    """Run ``fn`` and extract every ``'''...'''`` literal from the captured code.

    Pass ``remove_quote=True`` (consumed here, not forwarded to ``fn``) to
    strip the surrounding triple quotes from each match.
    """
    remove_quote = kwargs.pop("remove_quote", False)
    # pyrefly: ignore [bad-argument-type]
    result, source_codes = run_and_get_code(fn, *args, **kwargs)
    kernels = [
        match
        for code in source_codes
        for match in re.findall(r"'''.*?'''", code, re.DOTALL)
    ]
    if remove_quote:
        kernels = [match[3:-3] for match in kernels]
    return result, kernels
def run_fw_bw_and_get_code(fn: Callable[..., Any]) -> tuple[Any, list[str]]:
    """Like ``run_and_get_code`` but also backpropagates ``fn()``'s summed result."""

    def _forward_then_backward() -> Any:
        out = fn()
        out.sum().backward()
        return out

    return run_and_get_code(_forward_then_backward)
def get_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> list[str]:
    """Get the inductor-generated code, but skip any actual compilation or running.

    ``GraphLowering.compile_to_module`` is patched to only run codegen. It
    records the wrapper code, and the kernel code when present, then returns
    a dummy module whose ``call`` does nothing. Dynamo is reset before
    ``fn`` runs, and ``fn``'s return value is discarded.
    """
    from .graph import GraphLowering

    captured: list[str] = []

    def save_output_code(code: str) -> None:
        captured.append(code)

    def patched_compile_to_module(self: GraphLowering) -> Any:
        class DummyModule:
            """This is empty to replace the generated triton module"""

            def __init__(self) -> None:
                pass

            def call(self, *args: Any, **kwargs: Any) -> None:
                # Don't do anything when called
                pass

        if self.cpp_wrapper:
            wrapper_code, kernel_code = self.codegen_with_cpp_wrapper()
        else:
            wrapper_code, kernel_code = self.codegen()
        # Skip all the actual compiling; only record the generated source.
        save_output_code(wrapper_code.value)
        if kernel_code:
            save_output_code(kernel_code.value)
        return DummyModule()

    with (
        mock.patch.object(
            GraphLowering, "compile_to_module", patched_compile_to_module
        ),
        mock.patch.object(GraphLowering, "save_output_code", save_output_code),
    ):
        torch._dynamo.reset()
        # Note the return here is None
        fn(*args, **kwargs)
    return captured
def _first_of_one_or_two_codes(source_codes: list[str]) -> str:
    """Return the first captured code, asserting that one or two were captured."""
    # Can have two outputs if backwards was eagerly compiled
    count = len(source_codes)
    assert 1 <= count <= 2, f"expected one or two code outputs got {count}"
    return source_codes[0]


def get_triton_code(fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs) -> str:
    """Generated code for ``fn`` without compiling or running it (first output only)."""
    # pyrefly: ignore [bad-argument-type]
    return _first_of_one_or_two_codes(get_code(fn, *args, **kwargs))


def run_and_get_triton_code(
    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
) -> str:
    """Run ``fn`` and return the first captured generated code."""
    # pyrefly: ignore [bad-argument-type]
    _, source_codes = run_and_get_code(fn, *args, **kwargs)
    return _first_of_one_or_two_codes(source_codes)
def run_and_get_graph_lowering(
    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
) -> tuple[Any, list[GraphLowering]]:
    """Run ``fn`` and collect the ``GraphLowering`` behind each ``CompiledFxGraph``.

    ``CompiledFxGraph.__init__`` is wrapped, and the graph is read from
    positional index 2 of each call (counting ``self``).
    """
    from torch._inductor.graph import GraphLowering
    from torch._inductor.output_code import CompiledFxGraph

    original_init = CompiledFxGraph.__init__
    captured = []

    def recording_init(*init_args: Any, **init_kwargs: Any) -> None:
        original_init(*init_args, **init_kwargs)
        lowering = init_args[2]
        assert isinstance(lowering, GraphLowering)
        captured.append(lowering)

    with mock.patch.object(CompiledFxGraph, "__init__", recording_init):
        outcome = fn(*args, **kwargs)
    return outcome, captured
@contextlib.contextmanager
def override_lowering(
    aten_op: Callable[..., Any], override_fn: Callable[..., Any]
) -> Iterator[None]:
    """
    Override the lowering of aten_op with override_fn.
    The first argument of override_fn is the original lowering fn.

    The original registration is restored on exit, even if the body raises.
    """
    from torch._inductor import lowering

    original = lowering.lowerings[aten_op]
    replacement = functools.partial(override_fn, original)
    try:
        lowering.lowerings[aten_op] = replacement
        yield
    finally:
        lowering.lowerings[aten_op] = original
def add_scheduler_init_hook(
    pre_fn: Callable[..., Any], post_fn: Optional[Callable[..., Any]] = None
) -> Any:
    """
    Add hook functions to be called at the beginning and end of Scheduler.__init__.
    Used for unit tests.

    Returns an (unstarted) ``unittest.mock.patch.object`` patcher. Use it as a
    context manager or decorator to activate the hooks.
    """
    from torch._inductor.scheduler import Scheduler

    original_init = Scheduler.__init__

    def hooked_init(scheduler: Any, nodes: Any) -> Any:
        pre_fn(scheduler, nodes)
        init_result = original_init(scheduler, nodes)
        if post_fn:
            post_fn(scheduler, nodes)
        return init_result

    return unittest.mock.patch.object(Scheduler, "__init__", hooked_init)
def developer_warning(msg: str) -> None:
    """
    Warnings that will be actionable for PyTorch developers, but not
    end users. Allows us to easily disable them in stable releases but
    keep them on for nightly builds.

    Logged at WARNING when ``config.developer_warnings`` is set, else at INFO.
    """
    emit = log.warning if config.developer_warnings else log.info
    emit(msg)
def get_benchmark_name() -> Optional[str]:
    """
    An experimental API used only when config.benchmark_kernel is true.
    The benchmark name is only available at codegen time. So we can not
    directly call it in benchmark_all_kernels which is run after codegen.
    The function assumes the argument after --only is the benchmark name.
    It works for torchbench.py/hugginface.py/timm_models.py. But for ad-hoc
    scripts, this function may return None.
    There are 2 flavors of --only argument we need handle:
    1. --only model_name
    2. --only=model_name
    """
    argv = sys.argv
    if "--only" in argv:
        pos = argv.index("--only")
        candidate = argv[pos + 1] if pos + 1 < len(argv) else ""
        # Ignore a missing/empty value or one that looks like another flag.
        if candidate and not candidate.startswith("-"):
            return candidate

    prefix = "--only="
    for arg in argv:
        if arg.startswith(prefix):
            return arg[len(prefix) :]
    return None
def is_ones(items: Sequence[Any]) -> bool:
    """True when every element equals 1 (vacuously True for an empty sequence)."""
    for item in items:
        if not item == 1:
            return False
    return True
def is_zeros(items: Sequence[Any]) -> bool:
    """True when every element equals 0 (vacuously True for an empty sequence)."""
    for item in items:
        if not item == 0:
            return False
    return True
def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool:
    """True when every ``torch.Tensor`` in ``inputs`` is on CPU; non-tensors are ignored."""
    tensors = (item for item in inputs if isinstance(item, torch.Tensor))
    return all(t.device == torch.device("cpu") for t in tensors)
def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
    """int64 for expressions sympy knows are integer, otherwise float64."""
    assert isinstance(val, sympy.Expr), (
        "only support sympy.Expr as input to get_sympy_Expr_dtype"
    )
    # is_integer may be None (unknown); that is treated as non-integer.
    return torch.int64 if val.is_integer else torch.float64  # type: ignore[attr-defined]
@contextlib.contextmanager
def maybe_profile(should_profile: bool, *args: Any, **kwargs: Any) -> Iterator[Any]:
    """
    Optionally run the body under ``torch.profiler.profile``.

    When ``should_profile`` is true, ``*args``/``**kwargs`` are forwarded to the
    profiler and the profiler object is yielded; otherwise None is yielded.
    """
    if not should_profile:
        yield
        return
    with torch.profiler.profile(*args, **kwargs) as prof:
        yield prof
def parallel_num_threads() -> int:
    """Thread count for CPU codegen: ``config.cpp.threads`` if it is >= 1, else torch's current thread count."""
    configured = config.cpp.threads
    return configured if configured >= 1 else torch.get_num_threads()
@functools.cache
def get_backend_num_stages() -> int:
    """Backend's ``num_stages`` option (cached); defaults to 2 on ROCm/HIP and 3 otherwise."""
    from .runtime.triton_helpers import get_backend_options

    backend_options = get_backend_options()
    fallback_stages = 2 if torch.version.hip else 3
    return backend_options.get("num_stages", fallback_stages)
@functools.cache
def get_device_tflops(dtype: torch.dtype) -> float:
    """
    Estimate peak TFLOPS for ``dtype`` on the current device (cached per dtype).

    The datasheet value from ``datasheet_tops`` is preferred. Only when it is
    unavailable do we fall back to Triton's (less accurate) estimate, which
    supports float16 / bfloat16 / float32 only (asserted).
    """
    use_tf32 = torch.backends.cuda.matmul.fp32_precision == "tf32"
    ds_tops = datasheet_tops(dtype, is_tf32=use_tf32)
    if ds_tops is not None:
        return ds_tops

    from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops

    sm80_or_later = (
        torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0)
    )
    assert dtype in (torch.float16, torch.bfloat16, torch.float32)

    # Triton API change in https://github.com/triton-lang/triton/pull/2293:
    # newer versions take the SM clock rate as an extra positional argument.
    extra_args: tuple[Any, ...] = ()
    if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"):
        from torch._utils_internal import max_clock_rate

        extra_args = (max_clock_rate(),)

    if dtype in (torch.float16, torch.bfloat16) and sm80_or_later:
        return get_max_tensorcore_tflops(dtype, *extra_args)
    # Half types on pre-SM80 devices and float32 fall through to the fp32 estimate.
    if use_tf32:
        return get_max_tensorcore_tflops(torch.float32, *extra_args)
    return get_max_simd_tflops(torch.float32, *extra_args)
@functools.cache
def get_gpu_dram_gbps() -> int:
    """Return ``triton.testing.get_dram_gbps()``, computed once per process."""
    from triton.testing import get_dram_gbps as _triton_dram_gbps

    return _triton_dram_gbps()
def get_gpu_shared_memory() -> int:
    """Return ``max_shared_mem`` for device 0 as reported by the active Triton driver (0 if absent)."""
    from triton.runtime import driver

    device_props = driver.active.utils.get_device_properties(0)
    return device_props.get("max_shared_mem", 0)
def get_max_numwarps() -> int:
    """
    Maximum warps per block, i.e. ``max_threads_per_block // warp_size``.

    Uses the current CUDA device's properties when CUDA is available;
    otherwise assumes 32-thread warps and 1024 threads per block.
    """
    if not torch.cuda.is_available():
        default_warp_size = 32
        default_max_threads_per_block = 1024
        return default_max_threads_per_block // default_warp_size
    props = torch.cuda.get_device_properties()
    # pyrefly: ignore [missing-attribute]
    return props.max_threads_per_block // props.warp_size
def is_welford_reduction(reduction_type: str) -> bool:
    """True for any reduction type whose name begins with "welford"."""
    prefix = "welford"
    return reduction_type.startswith(prefix)


def reduction_num_outputs(reduction_type: str) -> int:
    """Number of outputs a reduction produces: 3 for welford variants, 2 for online_softmax_reduce, else 1."""
    if is_welford_reduction(reduction_type):
        return 3
    return 2 if reduction_type == "online_softmax_reduce" else 1
def is_linux() -> bool:
    """True when ``platform.system()`` reports Linux."""
    system_name = platform.system()
    return system_name == "Linux"


def is_windows() -> bool:
    """True when ``sys.platform`` is ``"win32"``."""
    current_platform = sys.platform
    return current_platform == "win32"
def has_free_symbols(itr: Iterable[Any]) -> bool:
    """True if any element is a sympy expression that is not a plain number."""
    for item in itr:
        if isinstance(item, sympy.Expr) and not item.is_number:
            return True
    return False
def is_dynamic(*args: Any) -> bool:
    """
    Return True if any tensor-like IR argument has a symbolic size or stride.

    Arguments that are not IR nodes are skipped. IR nodes that are not one of
    TensorBox / StorageBox / BaseView / ComputedBuffer / Buffer raise TypeError.
    """
    from . import ir

    tensor_like = (
        ir.TensorBox,
        ir.StorageBox,
        ir.BaseView,
        ir.ComputedBuffer,
        ir.Buffer,
    )
    for arg in args:
        if isinstance(arg, tensor_like):
            # Strides are only queried when the sizes are fully static.
            if has_free_symbols(arg.maybe_get_size() or ()):
                return True
            if has_free_symbols(arg.maybe_get_stride() or ()):
                return True
        elif isinstance(arg, ir.IRNode):
            raise TypeError(f"unexpected type for is_dynamic {type(arg)}")
    return False
# Placeholder strings used in triton codegen.
class Placeholder(enum.Enum):
    """Tokens written into generated Triton source and replaced later."""

    # Stands in for the kernel's actual name, e.g. "triton_" in "def triton_".
    KERNEL_NAME = "KERNEL_NAME"
    # Stands in for a descriptive kernel name; when unique_kernel_names = False
    # it is replaced with a string carrying more information.
    DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME"
def pass_execution_and_save(
    func: Callable[..., Any], gm: GraphModule, inp: Sequence[Any], msg: str
) -> None:
    """
    Run graph pass ``func`` on ``gm`` and save the before/after graphs to a temp file.

    Shape propagation is run with ``inp`` first. After the pass, the graph is
    stably re-sorted topologically, linted, and recompiled. One INFO log line
    reports the temp file path, whether the graph text changed, and the time
    the pass took.

    Args:
        func: graph pass, called as ``func(gm.graph)``.
        gm: module transformed in place.
        inp: example inputs used for shape propagation.
        msg: label for the ``GraphTransformObserver`` and the log line.
    """
    from .pattern_matcher import stable_topological_sort

    # delete=False: the path is logged so the user can inspect the dump later.
    # With the default (delete=True) the file is removed as soon as the
    # ``with`` block exits, so the logged path would point at nothing.
    with tempfile.NamedTemporaryFile(
        mode="w",
        encoding="utf-8",
        delete=False,
    ) as f:
        ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp)
        graph_before = str(gm.graph)
        print(f"Before:\n{graph_before}", file=f)

        start_time = datetime.now()
        with GraphTransformObserver(gm, msg):
            func(gm.graph)
        time_elapsed = datetime.now() - start_time

        # recompile graph
        stable_topological_sort(gm.graph)
        gm.graph.lint()
        gm.recompile()

        graph_after = str(gm.graph)
        print(f"After:\n{graph_after}", file=f)
        log.info(
            "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s",
            msg,
            f.name,
            graph_before == graph_after,
            time_elapsed,
        )
def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) -> bool:
    """True if ``input_buf`` is a ``CppTemplateBuffer`` whose layout is a ``MultiOutputLayout``."""
    from . import ir

    if not isinstance(input_buf, ir.CppTemplateBuffer):
        return False
    return isinstance(input_buf.layout, ir.MultiOutputLayout)


def is_output_of_multi_outputs_template(
    input_buf: Optional[Union[Buffer, Operation]],
) -> bool:
    """True if ``input_buf`` is a ``MultiOutput`` with exactly one input, which is a multi-output template buffer."""
    from . import ir

    if not isinstance(input_buf, ir.MultiOutput) or len(input_buf.inputs) != 1:
        return False
    return is_multi_outputs_template(input_buf.inputs[0])  # type: ignore[arg-type]
def is_collective(
    node: Optional[Union[Node, Operation]],
    op: Optional[torch._ops.OperatorBase] = None,
) -> bool:
    """
    Return True if ``node`` is a (non-wait) collective kernel.

    A ``_CollectiveKernel`` that is not a ``_WaitKernel`` matches when ``op`` is
    None or is exactly its ``op_overload``. A node whose exact type is
    ``FallbackKernel`` also matches if it wraps one of torchrec's
    all_to_all_single / all_gather_into_tensor / reduce_scatter_tensor ops
    (``op`` is not consulted for these).
    """
    if node is None:
        return False

    from . import ir

    if (
        isinstance(node, ir._CollectiveKernel)
        and not isinstance(node, ir._WaitKernel)
        and (op is None or node.op_overload is op)
    ):
        return True

    # TODO: this is a temporary solution to ensure that we can identify torchrec's
    # communication ops. But in order to allow better communication and computation
    # overlap, torchrec's communication ops should be not used.
    if type(node) is not ir.FallbackKernel:
        return False

    torchrec_ops = torch.ops.torchrec
    for op_name in (
        "all_to_all_single",
        "all_gather_into_tensor",
        "reduce_scatter_tensor",
    ):
        # NOTE: the `hasattr()` check is to bypass errors such as the following:
        # AttributeError: '_OpNamespace' 'torchrec' object has no attribute 'all_to_all_single'
        if (
            hasattr(torchrec_ops, op_name)
            and node.op_overload == getattr(torchrec_ops, op_name).default
        ):
            return True
    return False
def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool:
    """True only when ``node``'s exact type is ``ir._WaitKernel`` (subclasses do not match)."""
    from . import ir

    node_type = type(node)
    return node_type is ir._WaitKernel
def contains_collective(
    snode: BaseSchedulerNode,
    filter_fn: Optional[Callable[[BaseSchedulerNode], bool]] = None,
) -> bool:
    """
    Return True if ``snode`` is a collective, or is a group that contains one.

    Args:
        snode: scheduler node to inspect. ``GroupedSchedulerNode``s are
            searched recursively through their ``snodes``.
        filter_fn: optional extra predicate a collective node must satisfy
            to count. It is applied at every nesting level, including to
            nodes inside groups.
    """
    from torch._inductor.scheduler import GroupedSchedulerNode

    if isinstance(snode, GroupedSchedulerNode):
        # Forward filter_fn so nested collectives are filtered the same way
        # as top-level ones instead of matching unconditionally.
        return any(contains_collective(x, filter_fn) for x in snode.snodes)
    return is_collective(snode.node) and (filter_fn is None or filter_fn(snode))
def contains_wait(snode: BaseSchedulerNode) -> bool:
    """Return True if ``snode`` is a wait kernel, or is a group that (recursively) contains one."""
    from torch._inductor.scheduler import GroupedSchedulerNode

    if not isinstance(snode, GroupedSchedulerNode):
        return is_wait(snode.node)
    return any(contains_wait(child) for child in snode.snodes)
def is_fallback_op(
    node: Optional[Operation],
    op: Union[torch._ops.OpOverload, Collection[torch._ops.OpOverload]],
) -> bool:
    """True if ``node`` is an ``ir.FallbackKernel`` whose ``op_overload`` is ``op`` or is contained in ``op``."""
    from . import ir

    if not isinstance(node, ir.FallbackKernel):
        return False
    candidates = [op] if isinstance(op, torch._ops.OpOverload) else op
    return node.op_overload in candidates
- def buf_name_to_fused_snode(
- buf_name: str, name_to_buf: dict[str, Any], name_to_fused_node: dict[str, Any]
- ) -> Any:
- return name_to_fused_node[name_to_buf[buf_name].defining_op.get_name()]
- def find_recursive_deps_of_node(
- snode: BaseSchedulerNode,
- collected_node_set: MutableSet[BaseSchedulerNode],
- name_to_buf: dict[str, SchedulerBuffer],
- name_to_fused_node: dict[str, BaseSchedulerNode],
- criteria_cb: Callable[[Any], bool] = lambda snode: False,
- ) -> None:
- if criteria_cb(snode):
- return
- collected_node_set.add(snode)
- for dep in snode.unmet_dependencies:
- defining_op_for_dep = buf_name_to_fused_snode(
- dep.name, name_to_buf, name_to_fused_node
- )
- if defining_op_for_dep in collected_node_set:
- continue
- find_recursive_deps_of_node(
- defining_op_for_dep,
- collected_node_set,
- name_to_buf,
- name_to_fused_node,
- criteria_cb=criteria_cb,
- )
- def find_recursive_users_of_node(
- snode: BaseSchedulerNode,
- collected_node_set: MutableSet[BaseSchedulerNode],
- name_to_buf: dict[str, SchedulerBuffer],
- name_to_fused_node: dict[str, BaseSchedulerNode],
- criteria_cb: Callable[[Any], bool] = lambda snode: False,
- ) -> None:
- if criteria_cb(snode):
- return
- collected_node_set.add(snode)
- for o in snode.get_outputs():
- for user in o.users:
- assert user.node is not None
- if user.node.get_name() == "OUTPUT":
- continue
- if user.node.get_name() not in name_to_fused_node:
- continue
- user_op = name_to_fused_node[user.node.get_name()]
- if user_op in collected_node_set:
- continue
- find_recursive_users_of_node(
- user_op,
- collected_node_set,
- name_to_buf,
- name_to_fused_node,
- criteria_cb=criteria_cb,
- )
- def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int) -> int:
- "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)"
- num_rng_seed_offset_inputs = (
- 2 if torch._functorch.config.functionalize_rng_ops else 0
- )
- # AOT won't lift any parameters if we're inlining NN Modules
- # however desugaring subclasses will still add arguments
- # resulted in extra fixed inputs https://github.com/pytorch/pytorch/issues/130502
- return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs
def count_tangents(fx_g: torch.fx.GraphModule) -> int:
    """
    Count the saved-tensor (static) placeholders at the start of a backward graph.

    Placeholders whose names contain "tangents", "bwd_seed", "bwd_base_offset"
    or "bwd_rng_state" are not saved tensors. All saved-tensor placeholders must
    come before the others (asserted). Returns how many saved-tensor
    placeholders there are.
    """
    non_saved_markers = ("tangents", "bwd_seed", "bwd_base_offset", "bwd_rng_state")

    def is_saved_tensor(x: Node) -> bool:
        return not any(marker in x.name for marker in non_saved_markers)

    placeholders = [n for n in fx_g.graph.nodes if n.op == "placeholder"]
    static_arg_idxs = [
        position for position, n in enumerate(placeholders) if is_saved_tensor(n)
    ]
    assert static_arg_idxs == list(range(len(static_arg_idxs)))
    return len(static_arg_idxs)
@dataclasses.dataclass
class BoxedBool:
    """A mutable boolean box, so a flag can be flipped through a shared reference."""

    value: bool

    def __bool__(self) -> bool:
        return self.value

    @staticmethod
    def disable(obj: Any) -> Union[BoxedBool, bool]:
        """Set ``obj.value`` to False and return ``obj`` if it is a BoxedBool; otherwise return False."""
        if not isinstance(obj, BoxedBool):
            return False
        obj.value = False
        return obj
@contextlib.contextmanager
def collect_defined_kernels(kernel_list: list[str]) -> Iterator[None]:
    """
    While active, append the source of every kernel passed to
    ``PythonWrapperCodegen.define_kernel`` to ``kernel_list``.

    The real ``define_kernel`` still runs. The patch is removed on exit.
    """
    from .codegen.wrapper import PythonWrapperCodegen

    real_define_kernel = PythonWrapperCodegen.define_kernel

    def define_kernel(
        self: PythonWrapperCodegen,
        kernel_name: str,
        kernel_code: str,
        metadata: Optional[str] = None,
        gpu: bool = True,
        cpp_definition: Optional[str] = None,
    ) -> Any:
        kernel_list.append(kernel_code)
        return real_define_kernel(
            self, kernel_name, kernel_code, metadata, gpu, cpp_definition
        )

    patcher = mock.patch.object(PythonWrapperCodegen, "define_kernel", define_kernel)
    with patcher:
        yield
def get_cloned_parameter_buffer_name(name: str) -> str:
    """Return ``name`` with the ``__original__`` suffix used for cloned parameters/buffers."""
    suffix = "__original__"
    return name + suffix
def is_gpu(device: Optional[str]) -> bool:
    """True if the device-type string ``device`` is one of ``GPU_TYPES``."""
    return device in GPU_TYPES


def is_rocm() -> bool:
    """Check if we're running on ROCm/HIP platform."""
    hip_version = torch.version.hip
    return hip_version is not None


def device_need_guard(device: str) -> bool:
    """True for GPU device types that need a device guard; MPS is excluded."""
    # TODO: MPS does not expose streams now
    if device == "mps":
        return False
    return is_gpu(device)
def needs_fallback_due_to_atomic_add_limitations(dtype: torch.dtype) -> bool:
    """
    Whether atomic adds of ``dtype`` need a fallback.

    bfloat16 needs one on CUDA devices below compute capability 9.0 and
    whenever XPU is the available accelerator. int64 and bool always need
    one. Everything else does not.
    """
    if dtype == torch.bfloat16:
        if torch.cuda.is_available():
            return torch.cuda.get_device_capability() < (9, 0)
        if torch.xpu.is_available():
            return True
    return dtype in (torch.int64, torch.bool)
def use_scatter_fallback(
    op_overload: torch._ops.OpOverload,
    reduction_type: Optional[str],
    self_dtype: torch.dtype,
    src_dtype: torch.dtype,
    src_device_type: str,
    src_is_tensor: bool,
) -> bool:
    """
    Decide whether a scatter op should use the ATen fallback.

    ``scatter_reduce`` / ``scatter_reduce_`` with no reduction type never fall
    back. Otherwise, fall back when any of the following holds:
      * the reduction is neither None nor the op's native add ("add" for
        ``scatter_``, "sum" otherwise);
      * a tensor source on a GPU device has a dtype with atomic-add limitations;
      * it is a CPU ``scatter_reduce_`` "sum" with a tensor source,
        ``config.cpp.fallback_scatter_reduce_sum`` is on, and either dynamic
        threads are enabled or the thread count is not 1;
      * the native add reduction is applied to a bool/int64 ``self``;
      * deterministic algorithms are enabled.
    """
    packet = op_overload.overloadpacket
    if (
        packet in (torch.ops.aten.scatter_reduce_, torch.ops.aten.scatter_reduce)
        and reduction_type is None
    ):
        return False

    native_reduce = "add" if packet == torch.ops.aten.scatter_ else "sum"
    if reduction_type not in (None, native_reduce):
        return True
    if (
        src_is_tensor
        and is_gpu(src_device_type)
        and needs_fallback_due_to_atomic_add_limitations(src_dtype)
    ):
        return True
    if (
        packet == torch.ops.aten.scatter_reduce_
        and reduction_type == "sum"
        and src_is_tensor
        and src_device_type == "cpu"
        and config.cpp.fallback_scatter_reduce_sum
        and (config.cpp.dynamic_threads or parallel_num_threads() != 1)
    ):
        return True
    if reduction_type == native_reduce and self_dtype in (torch.bool, torch.int64):
        return True
    return torch.are_deterministic_algorithms_enabled()
- def dump_node_schedule(node_schedule: Sequence[BaseSchedulerNode]) -> None:
- """
- An API that can be used in pdb to dump a node_schedule.
- Right mainly dump the read/write dependencies but can add more as needed.
- """
- from torch._inductor.codegen.simd import DisableReduction, EnableReduction
- from torch._inductor.scheduler import SchedulerNode
- print(f"Node schedule with {len(node_schedule)} nodes")
- for idx, node in enumerate(node_schedule):
- print(f" {idx:3}:")
- # pyrefly: ignore [unnecessary-comparison]
- if node is EnableReduction:
- print("enable reduction")
- # pyrefly: ignore [unnecessary-comparison]
- elif node is DisableReduction:
- print("disable reduction")
- elif isinstance(node, SchedulerNode):
- is_red = node.is_reduction()
- print(f"{'red' if is_red else 'pw'} scheduler node")
- if is_red:
- assert node.node is not None
- print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined]
- print("ReadDep:")
- for dep in node.read_writes.reads:
- print(dep)
- print("WriteDep:")
- for dep in node.read_writes.writes:
- print(dep)
- else:
- raise RuntimeError(f"Unrecognized node type: {type(node)}")
def tensor_is_aligned(tensor: torch.Tensor) -> bool:
    """Return True if ``tensor``'s storage offset in bytes is statically known to be a multiple of ``GPU_ALIGN_BYTES``."""
    # See Note: [Input Alignment handling in Inductor]
    # Right now, we don't try to guard on the alignment of the storage offset.
    # When this comment was written, non-symbolic storage_offsets are not guarded on
    # but symbolic storage_offsets are. For consistency, we suppress guard creation
    # upon performing this check: that ensures that we don't add recompiles when we
    # add this logic.
    from torch.fx.experimental.symbolic_shapes import statically_known_true

    offset_bytes = tensor.storage_offset() * get_dtype_size(tensor.dtype)
    return statically_known_true(offset_bytes % GPU_ALIGN_BYTES == 0)


def should_assume_input_aligned(example_input: torch.Tensor) -> bool:
    """Whether a GPU input may be assumed aligned: by config, or because its offset is provably aligned."""
    # See Note: [Input Alignment handling in Inductor]
    # right now, we only care about alignment for cuda tensors.
    if not is_gpu(example_input.device.type):
        return False
    return config.assume_aligned_inputs or tensor_is_aligned(example_input)
- def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[None]:
- # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards()
- # If it's not available, return a nullcontext.
- # If we're dealing with cudagraphs, we might not have a tracing_context
- tracing_context = torch._guards.TracingContext.try_get()
- if not tracing_context:
- return contextlib.nullcontext()
- # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
- if not tracing_context.fake_mode or not tracing_context.fake_mode.shape_env:
- return contextlib.nullcontext()
- shape_env = tracing_context.fake_mode.shape_env
- return shape_env.suppress_guards()
def run_and_get_cpp_code(
    fn: Callable[P, _T], *args: P.args, **kwargs: P.kwargs
) -> tuple[_T, str]:
    """
    Call ``fn(*args, **kwargs)`` with ``config.debug`` enabled and capture the output code.

    Dynamo is reset first. Records emitted through ``output_code_log`` (with
    its level temporarily set to DEBUG) during the call are captured as text.

    Returns:
        ``(result_of_fn, captured_log_text)``.
    """
    # We use the patch context manager instead of using it as a decorator.
    # In this way, we can ensure that the attribute is patched and unpatched correctly
    # even if this run_and_get_cpp_code function is called multiple times.
    with unittest.mock.patch.object(config, "debug", True):
        torch._dynamo.reset()
        import io
        import logging

        from torch._inductor.codecache import output_code_log

        log_capture_string = io.StringIO()
        ch = logging.StreamHandler(log_capture_string)
        output_code_log.addHandler(ch)
        prev_level = output_code_log.level
        output_code_log.setLevel(logging.DEBUG)
        try:
            result = fn(*args, **kwargs)
        finally:
            # Always detach the handler and restore the level, even if fn
            # raises; otherwise later runs keep writing into a stale buffer
            # and the logger stays at DEBUG.
            output_code_log.setLevel(prev_level)
            output_code_log.removeHandler(ch)
        s = log_capture_string.getvalue()
    return result, s
def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
    """
    Find the ShapeEnv associated with ``inputs``, or None.

    The shape env of the detected fake mode is preferred. Otherwise, the shape
    env of the first SymInt found is returned. Inputs are scanned in order; for
    each input, the input itself is checked, then (for tensors) its sizes, then
    its strides.
    """
    fake_mode = detect_fake_mode(inputs)
    # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
    # pass in real inputs for now.
    # if len(inputs) > 0:
    #     assert fake_mode is not None, breakpoint()
    if fake_mode is not None:
        return fake_mode.shape_env

    def tensor_dims(t: torch.Tensor) -> Iterator[Any]:
        # Lazily chained so strides are only queried if no size is symbolic.
        yield from t.size()
        yield from t.stride()

    # When there are no tensor inputs, get shape_env from the first SymInt.
    for inp in inputs:
        if isinstance(inp, torch.SymInt):
            return inp.node.shape_env
        if isinstance(inp, torch.Tensor):
            for dim in tensor_dims(inp):
                if isinstance(dim, torch.SymInt):
                    return dim.node.shape_env
    # TODO(voz): Should we always have one anyway?
    return None
def align_inputs_from_check_idxs(
    model: Callable[[list[InputType]], _T],
    inputs_to_check: Sequence[int],
    mutated_input_idxs: OrderedSet[int],
) -> Callable[[list[InputType]], _T]:
    """
    Wrap ``model`` so the inputs at ``inputs_to_check`` are realigned before each call.

    Misaligned inputs are swapped for aligned clones inside the input list.
    For cloned inputs whose index is in ``mutated_input_idxs``, the clone's
    contents are copied back into the original tensor after ``model`` runs,
    so callers still observe the mutation. ``model`` is returned unchanged
    when there is nothing to check.
    """
    if len(inputs_to_check) == 0:
        return model

    def run(new_inputs: list[InputType]) -> Any:
        originals, clones = copy_misaligned_inputs(
            new_inputs, inputs_to_check, mutated_input_idxs
        )
        out = model(new_inputs)
        # Reflect mutations made on an aligned clone back into the original tensor.
        if originals:
            torch._foreach_copy_(originals, clones)
        return out

    return run
def clone_preserve_strides(x: torch.Tensor) -> torch.Tensor:
    """
    Clone ``x`` into fresh storage while keeping its exact sizes and strides.

    Only the storage span the view touches is copied:
    ``1 + sum((size - 1) * stride)`` elements, or 0 when ``x`` is empty.
    """
    if 0 in x.size():
        needed_size = 0  # no elements, nothing to copy
    else:
        needed_size = 1 + sum(
            (dim - 1) * step for dim, step in zip(x.size(), x.stride())
        )
    flat_copy = torch.as_strided(x, (needed_size,), (1,)).clone()
    return torch.as_strided(flat_copy, x.size(), x.stride())
def copy_misaligned_inputs(
    new_inputs: list[InputType],
    check_inputs_idxs: Sequence[int],
    return_pair_idxs: Optional[OrderedSet[int]] = None,
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    """
    Replace misaligned tensors with stride-preserving clones, in place.

    A tensor at one of ``check_inputs_idxs`` is replaced when its data pointer
    is not a multiple of ``ALIGNMENT``. Returns ``(old_tensors, new_tensors)``:
    the original and clone for every replaced index that is also in
    ``return_pair_idxs``.
    """
    old_tensors: list[torch.Tensor] = []
    new_tensors: list[torch.Tensor] = []
    # hoist above loop because this is on the hot path
    want_pairs = return_pair_idxs is not None
    for i in check_inputs_idxs:
        original = new_inputs[i]
        assert isinstance(original, torch.Tensor), (
            f"Expected tensors only, but got: {type(original)}"
        )
        if not original.data_ptr() % ALIGNMENT:
            continue  # already aligned
        clone = clone_preserve_strides(original)
        new_inputs[i] = clone
        if want_pairs and i in return_pair_idxs:  # type: ignore[operator]
            old_tensors.append(original)
            new_tensors.append(clone)
    return old_tensors, new_tensors
def remove_unaligned_input_idxs(
    inputs: Sequence[InputType],
    static_input_idxs: Sequence[int],
) -> Sequence[int]:
    """
    Keep only the static input indices whose inputs are ``ALIGNMENT``-aligned tensors.

    All inputs are required to be aligned, so any that aren't get a copy and
    can no longer be treated as static. Returns ``static_input_idxs`` itself
    when nothing is dropped, otherwise a new list of the remaining indices.
    """
    aligned = [
        idx
        for idx in static_input_idxs
        if isinstance(inputs[idx], torch.Tensor)
        and inputs[idx].data_ptr() % ALIGNMENT == 0
    ]
    return aligned if len(aligned) != len(static_input_idxs) else static_input_idxs
def expr_fits_within_32bit(e: sympy.Expr) -> bool:
    """
    Return True if index expression ``e`` may be treated as fitting in int32.

    - With ``config.assume_32bit_indexing``: calls ``check_leq(e, int32 max)``
      and returns True.
    - Returns True if ``e <= int32 max`` is statically provable.
    - Under AOT compilation: returns False if ``e < 1e20`` is provable, i.e.
      ``e`` has a finite upper bound that was not proven to fit.
    - Otherwise: ``e`` must have a size hint, and that hint must be
      <= int32 max.
    """
    from .virtualized import V

    int_max = torch.iinfo(torch.int32).max
    sizevars = V.graph.sizevars
    size_hint = sizevars.size_hint
    has_hint = sizevars.shape_env.has_hint

    if config.assume_32bit_indexing:
        sizevars.check_leq(e, int_max)  # type: ignore[arg-type]
        return True

    # Allow for unhinted e as long as we can still statically prove
    # (e.g., via ValueRanges) that it is still in bounds
    if sizevars.statically_known_true(e <= int_max):
        return True

    # AOTI doesn't guard on < 2**32, so checking hints isn't a viable option,
    # in case the hinted value is < 2**32, but the allowed range is larger.
    # However, to prevent possible perf regressions on pre-existing AOTI models
    # which don't set an upper bound on the valid range, we'll skip the check.
    # To recap:
    # - If using AOTI:
    #   - If allowed range has no upper bound, then check the hint to determine
    #     whether this fits in int32
    #   - If allowed range does have an upper bound, then obey the upper bound
    #     (check whether upper bound < int32_max) without checking the hint.
    # 1e20 is > INT64_MAX: if e is provably below it, e has an upper bound
    # between int_max and infinity, so it could potentially hold int64 values.
    if V.aot_compilation and sizevars.statically_known_true(e < 1e20):
        return False

    # Otherwise, the hint MUST exist and be in range
    return has_hint(e) and size_hint(e) <= int_max
def set_tracing_context_output_strides(
    example_inputs: Sequence[Any], compiled_graph: CompiledFxGraph
) -> None:
    """
    Return the compiled graph's output strides to the caller via TracingContext.

    This only acts when a tracing context exists and its ``output_strides`` is
    not None; that list must be empty (asserted). Each entry of
    ``compiled_graph.output_strides`` is appended: None stays None. Otherwise
    every stride expression becomes ``int(e)`` when no shape env is found,
    ``deserialize_symexpr(e)`` when ``fakify_first_call`` is set, or
    ``evaluate_symexpr(e)`` in all other cases.
    """
    context = torch._guards.TracingContext.try_get()
    if context is None or context.output_strides is None:
        return
    assert len(context.output_strides) == 0
    shape_env = shape_env_from_inputs(example_inputs)
    assert compiled_graph.output_strides is not None

    # Loop invariant: whether the active context fakifies on first call.
    fakify_first_call = False
    if ctx := torch._guards.TracingContext.try_get():
        fakify_first_call = ctx.fakify_first_call

    def map_expr(e: Any) -> Union[float, int, SymInt, SymFloat, SymBool]:
        if shape_env is None:
            return int(e)
        if fakify_first_call:
            return shape_env.deserialize_symexpr(e)
        return shape_env.evaluate_symexpr(e)

    for exprs in compiled_graph.output_strides:
        if exprs is None:
            context.output_strides.append(None)
        else:
            context.output_strides.append(
                tuple(map_expr(e) for e in exprs)  # type: ignore[misc]
            )
def should_use_remote_fx_graph_cache() -> bool:
    """
    Whether to use the remote FX graph cache.

    An explicit ``config.fx_graph_remote_cache`` setting wins. Otherwise the
    cache is only enabled in fbcode, outside FB unit tests, when the internal
    remote-cache module can be imported and its ``REMOTE_CACHE_VERSION`` is at
    least the "pytorch/remote_cache:fx_graph_memcache_version" knob value.
    """
    explicit_setting = config.fx_graph_remote_cache
    if explicit_setting is not None:
        return explicit_setting
    if not config.is_fbcode() or torch._utils_internal.is_fb_unit_test():
        return False
    try:
        from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
    except ModuleNotFoundError:
        return False
    required_version = torch._utils_internal.justknobs_getval_int(
        "pytorch/remote_cache:fx_graph_memcache_version"
    )
    return REMOTE_CACHE_VERSION >= required_version
def normalize_name(name: str) -> str:
    """Replace every character outside ``[a-zA-Z0-9_]`` with an underscore."""
    disallowed = r"[^a-zA-Z0-9_]"
    return re.sub(disallowed, "_", name)
- # correct cases where Triton types names don't match PyTorch
- _triton_type_mapping = {
- "tl.bool": "tl.int1",
- "tl.float8_e4m3fn": "tl.float8e4nv",
- "tl.float8_e5m2": "tl.float8e5",
- "tl.float8_e4m3fnuz": "tl.float8e4b8",
- "tl.float8_e5m2fnuz": "tl.float8e5b16",
- # TODO: remove when support is added in triton
- # https://github.com/triton-lang/triton/issues/6054
- "tl.float8_e8m0fnu": "tl.uint8",
- "tl.float4_e2m1fn_x2": "tl.uint8",
- }
- _torch_triton_mapping = {v: k for k, v in _triton_type_mapping.items()}
- _triton_type_re = re.compile(r"^.*[.]")
- def triton_type(dtype: torch.dtype) -> str:
- """Convert torch.dtype to triton type"""
- triton_type_name = _triton_type_re.sub("tl.", str(dtype))
- return _triton_type_mapping.get(triton_type_name, triton_type_name)
- def triton_type_to_torch(dtype: str) -> torch.dtype:
- adjusted_type = _torch_triton_mapping.get(dtype, dtype)
- type_name = adjusted_type.replace("tl.", "")
- out_dtype = getattr(torch, type_name)
- assert isinstance(out_dtype, torch.dtype)
- return out_dtype
def is_same_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
    """
    True if two non-MKLDNN tensors are the same view of the same memory.

    Compares size, stride, dtype, device, storage data pointer and storage
    offset. Always False when ``data`` is an MKLDNN tensor.
    """
    if data.is_mkldnn:
        return False
    same_meta = (
        data.size() == value.size()
        and data.stride() == value.stride()
        and data.dtype == value.dtype
        and data.device == value.device
    )
    return (
        same_meta
        and data.untyped_storage().data_ptr() == value.untyped_storage().data_ptr()
        and data.storage_offset() == value.storage_offset()
    )


def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor) -> bool:
    """True if ``data`` is MKLDNN and matches ``value`` in size, dtype, device and MKLDNN data pointer."""
    if not data.is_mkldnn:
        return False
    return (
        data.size() == value.size()
        and data.dtype == value.dtype
        and data.device == value.device
        and torch.ops.mkldnn.data_ptr(data) == torch.ops.mkldnn.data_ptr(value)
    )
@functools.cache
def boolean_ops() -> tuple[str, ...]:
    """Return (and cache) the tuple of op names classified here as boolean ops."""
    comparisons = ("le", "lt", "ge", "gt", "eq", "ne")
    return (
        "isinf",
        "isnan",
        "logical_not",
        "logical_and",
        "signbit",
        "and_",
        *comparisons,
        "or_",  # TODO should remove this op
        "xor",
    )
- @dataclasses.dataclass
- class OpDtypeRule:
- type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND
- override_return_dtype: Optional[torch.dtype]
- op_dtype_propagation_rules: dict[str, OpDtypeRule] = {}
- def register_op_dtype_propagation_rules(
- name: str,
- type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
- override_return_dtype: Optional[torch.dtype],
- ) -> None:
- op_dtype_propagation_rules[name] = OpDtypeRule(
- type_promotion_kind, override_return_dtype
- )
- op_requires_libdevice_fp64: OrderedSet[str] = OrderedSet()
- def register_op_requires_libdevice_fp64(name: str) -> None:
- op_requires_libdevice_fp64.add(name)
def get_current_backend(device_type: Optional[str] = None) -> str:
    """
    Return the codegen backend name for ``device_type``.

    If no device type is given, the current graph's device type is used.
    cpu and xpu use their configured backends, mps is always "mps", and any
    other device uses ``config.cuda_backend``.
    """
    from torch._inductor.virtualized import V

    if not device_type:
        device_type = V.graph.get_current_device_or_throw().type
    if device_type == "mps":
        return "mps"
    if device_type == "cpu":
        return config.cpu_backend
    if device_type == "xpu":
        return config.xpu_backend
    return config.cuda_backend


def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
    """Maybe upcast [b]float16 to float32 (only with triton.codegen_upcast_to_fp32 on the triton backend)."""
    is_half_precision = dtype in (torch.float16, torch.bfloat16)
    if (
        is_half_precision
        and config.triton.codegen_upcast_to_fp32
        and get_current_backend() == "triton"
    ):
        return torch.float32
    return dtype
KeyType = TypeVar("KeyType")
ValType = TypeVar("ValType")


class ScopedDict(MutableMapping[KeyType, ValType]):
    """
    A mutable mapping layered over a base mapping that is never modified.

    Writes go to an overlay dict (``new_items``) that shadows the base
    (``original_dict``). Iteration yields the base's keys first, then keys
    that exist only in the overlay. Deletion is not supported.
    """

    def __init__(self, original_dict: Mapping[KeyType, ValType]):
        self.original_dict = original_dict
        self.new_items: dict[KeyType, ValType] = {}

    def __getitem__(self, key: KeyType) -> ValType:
        overlay = self.new_items
        return overlay[key] if key in overlay else self.original_dict[key]

    def __setitem__(self, key: KeyType, value: ValType) -> None:
        self.new_items[key] = value

    def __contains__(self, key: object) -> bool:
        return key in self.new_items or key in self.original_dict

    def get(self, key: KeyType, default: Optional[ValType] = None) -> Optional[ValType]:  # type: ignore[override]
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)

    def __len__(self) -> int:
        overlay_only = sum(1 for k in self.new_items if k not in self.original_dict)
        return len(self.original_dict) + overlay_only

    def __iter__(self) -> Iterator[KeyType]:
        yield from self.original_dict
        yield from (k for k in self.new_items if k not in self.original_dict)

    def __bool__(self) -> bool:
        return bool(self.original_dict or self.new_items)

    def __delitem__(self, key: KeyType) -> None:
        raise NotImplementedError
@dataclass_transform(frozen_default=True)
def ir_dataclass(cls: Optional[type[Any]] = None, /, *, frozen: bool = True) -> Any:
    """
    Dataclass decorator for IR classes: always ``kw_only=True``, frozen by default.

    Works bare (``@ir_dataclass``) or with arguments (``@ir_dataclass(frozen=False)``).
    """

    def wrap(cls: _T) -> _T:
        return dataclasses.dataclass(cls, kw_only=True, frozen=frozen)  # type: ignore[call-overload]

    return wrap if cls is None else wrap(cls)
- def get_donated_idxs() -> Optional[list[int]]:
- tracing_context = torch._guards.TracingContext.try_get()
- if tracing_context is not None and tracing_context.fw_metadata:
- return tracing_context.fw_metadata.bw_donated_idxs
- return None
- class TritonAttrsDescriptorVersion(enum.Enum):
- V0_NO_TRITON = 0
- V1_COMPILER = 1 # triton.compiler.compiler.AttrsDescriptor
- V2_BACKENDS = 2 # triton.backends.compiler.AttrsDescriptor
- V3_BACKENDS_TUPLE = (
- 3 # triton.backends.compiler.AttrsDescriptor, but with tuple support
- )
- V4_DICT = 4 # a raw dict
- @functools.cache
- def get_triton_attrs_descriptor_version() -> TritonAttrsDescriptorVersion:
- if importlib.util.find_spec("triton") is None:
- return TritonAttrsDescriptorVersion.V0_NO_TRITON
- import triton.backends.compiler
- import triton.compiler.compiler
- if hasattr(triton.backends.compiler, "AttrsDescriptor"):
- # Triton 3.2.0
- # AttrsDescriptor was moved from triton.compiler.compiler to triton.backends.compiler.
- # AttrsDescriptor and its serialization format were also changed.
- # TODO: implement V3_BACKENDS_TUPLE
- # On Dec 9, 2024, tuple support (triton #5220) was implemented and breaks handling.
- # We don't have a way to detect this (and haven't implemented this version)
- return TritonAttrsDescriptorVersion.V2_BACKENDS
- elif hasattr(triton.compiler.compiler, "AttrsDescriptor"):
- # Triton 3.0.0
- return TritonAttrsDescriptorVersion.V1_COMPILER
- else:
- # After Jan 1, 2025
- # AttrsDescriptor was removed and replaced with a raw dict.
- return TritonAttrsDescriptorVersion.V4_DICT
- def triton_version_uses_attrs_dict() -> bool:
- return get_triton_attrs_descriptor_version() == TritonAttrsDescriptorVersion.V4_DICT
- def get_op_names(op: torch._ops.OperatorBase) -> tuple[str, str]:
- op_overload_packet_name: str = op.name()
- op_overload_name = (
- f"{op_overload_packet_name}.{op._overloadname}"
- if isinstance(op, torch._ops.OpOverload)
- else op_overload_packet_name
- )
- return op_overload_packet_name, op_overload_name
- def _fx_node_is_input_dependent_cudagraph_unsafe(fx_node: torch.fx.Node) -> bool:
- """
- Check if an FX node is cudagraph-unsafe based on its input arguments.
- Some ops are only cudagraph-unsafe depending on their inputs (e.g., index_put
- with boolean indices triggers .nonzero() during capture, but integer indices
- are safe).
- """
- from torch.fx.operator_schemas import normalize_function
- target = fx_node.target
- if not isinstance(target, torch._ops.OpOverload):
- return False
- # index_put with boolean indices triggers .nonzero() during capture
- if target in (
- torch.ops.aten.index_put.default,
- torch.ops.aten.index_put_.default,
- torch.ops.aten._unsafe_index_put.default,
- ):
- normalized = normalize_function(
- target, fx_node.args, fx_node.kwargs, normalize_to_only_use_kwargs=True
- )
- if normalized is not None:
- _, kwargs = normalized
- indices = kwargs["indices"]
- for idx in indices:
- if idx is not None and idx.meta["val"].dtype in (
- torch.bool,
- torch.uint8,
- ):
- return True
- return False
def is_cudagraph_unsafe_fx_node(fx_node: torch.fx.Node) -> bool:
    """
    Check if an FX node is cudagraph-unsafe.

    This includes:
    - Ops in FORBIDDEN_CUDAGRAPH_OPS (CPU sync, dynamic alloc, etc.)
    - Ops with the cudagraph_unsafe tag
    - Input-dependent unsafe ops (e.g., index_put with boolean indices)
    - Ops with sparse tensor outputs
    """
    target = fx_node.target
    if str(target) in FORBIDDEN_CUDAGRAPH_OPS:
        return True
    if (
        isinstance(target, torch._ops.OpOverload)
        and torch._C.Tag.cudagraph_unsafe in target.tags  # type: ignore[attr-defined]
    ):
        return True
    if _fx_node_is_input_dependent_cudagraph_unsafe(fx_node):
        return True

    # Sparse outputs, looking at meta["val"] (a single value or a list/tuple of them).
    val = fx_node.meta.get("val")
    if val is None:
        return False
    outputs = val if isinstance(val, (list, tuple)) else [val]
    return any(isinstance(v, torch.Tensor) and v.is_sparse for v in outputs)
def is_cudagraph_unsafe_op(node: Operation) -> bool:
    """
    Return True if the scheduler-level ``node`` cannot be cudagraphed.

    Control-flow IR nodes (Conditional, WhileLoop) are always unsafe.
    FallbackKernel / ExternKernel nodes are unsafe when they carry an
    ``fx_node`` that ``is_cudagraph_unsafe_fx_node`` rejects, for example:
    a forbidden op, a ``cudagraph_unsafe``-tagged op, index_put with
    boolean indices, or a sparse tensor output. Everything else is
    considered safe.
    """
    from . import ir

    if isinstance(node, (ir.Conditional, ir.WhileLoop)):
        return True

    if isinstance(node, (ir.FallbackKernel, ir.ExternKernel)):
        origin = getattr(node, "fx_node", None)
        return origin is not None and is_cudagraph_unsafe_fx_node(origin)

    return False
def get_ld_library_path() -> str:
    """
    Return the library search path to hand to child processes.

    Starts from ``$LD_LIBRARY_PATH`` (empty if unset). In fbcode, when the
    PAR runtime path is available, ``<runtime_path>/runtime/lib`` is put in
    front of it.
    """
    ld_path = os.environ.get("LD_LIBRARY_PATH", "")
    if not config.is_fbcode():
        return ld_path

    from libfb.py.parutil import get_runtime_path

    runtime_root = get_runtime_path()
    if not runtime_root:
        return ld_path

    lib_dir = os.path.join(runtime_root, "runtime", "lib")
    if not ld_path:
        return lib_dir
    return os.pathsep.join([lib_dir, ld_path])
- def is_codegen_graph_partition_subgraph(wrapper: PythonWrapperCodegen) -> bool:
- from torch._inductor.codegen.wrapper import SubgraphPythonWrapperCodegen
- return (
- isinstance(wrapper, SubgraphPythonWrapperCodegen)
- and wrapper.partition_signatures is not None
- )
def is_using_cudagraph_partition() -> bool:
    """Return whether graph partitions will be wrapped.

    Wrapping happens when graph partitioning is enabled and either inductor
    cudagraphs is on or a user partition wrapper has been registered.
    """
    inductor_cfg = torch._inductor.config
    wants_wrapping = (
        inductor_cfg.triton.cudagraphs
        or _unstable_customized_partition_wrapper.wrapper is not None
    )
    return wants_wrapping and inductor_cfg.graph_partition
def dtype_from_size(size: int) -> torch.dtype:
    """Pick the narrowest integer dtype that can hold ``size``.

    Returns ``torch.int32`` when ``size`` is statically known to lie in
    ``[-2**31, 2**31)``; otherwise returns ``torch.int64``.
    """
    from .virtualized import V

    sizevars = V.graph.sizevars
    fits_in_int32 = sizevars.statically_known_lt(
        size, 2**31
    ) and sizevars.statically_known_geq(size, -(2**31))
    return torch.int32 if fits_in_int32 else torch.int64
SUPPORTED_MKLDNN_DEVICES = ("cpu", "xpu")


def is_mkldnn_bf16_supported(device_type: str) -> bool:
    """
    Report whether MKL-DNN BF16 is available for ``device_type``.

    For "cpu" the answer comes from ``torch.ops.mkldnn._is_mkldnn_bf16_supported``.
    Any device string containing "xpu" is treated as supported; anything
    else is not.
    """
    if device_type == "cpu":
        return torch.ops.mkldnn._is_mkldnn_bf16_supported()
    # Substring test so "xpu", "xpu:0", "xpu:1", ... all qualify.
    return "xpu" in device_type
def is_mkldnn_fp16_supported(device_type: str) -> bool:
    """
    Report whether MKL-DNN FP16 is available for ``device_type``.

    Device strings containing "xpu" ("xpu", "xpu:0", ...) are supported.
    "cpu" defers to ``torch.ops.mkldnn._is_mkldnn_fp16_supported``.
    Anything else is unsupported.
    """
    if "xpu" in device_type:
        return True
    if device_type != "cpu":
        return False
    return torch.ops.mkldnn._is_mkldnn_fp16_supported()
def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
    """
    Render ``elements`` as a plain-text table with a header row.

    Layout:
    - Columns are joined by ``|``.
    - Every cell is padded with one space on each side.
    - The header is followed by a row of dashes spanning the whole table.
    - Each column is as wide as its widest ``str()``-rendered cell, header
      included.

    Numbers (instances of ``numbers.Number`` other than ``bool``) are
    right-aligned; everything else is left-aligned.

    Args:
        elements: Table rows; every row must have ``len(headers)`` cells.
        headers: Column titles.

    Returns:
        The table as a single newline-joined string.
    """
    import numbers

    def _cell(value: object, width: int) -> str:
        # Always format the str() text so the printed cell matches the width
        # computed below. Formatting the raw value with a width spec raises
        # TypeError for None (and most plain objects) and prints bools as 1/0.
        text = str(value)
        if isinstance(value, numbers.Number) and not isinstance(value, bool):
            return f" {text:>{width}} "
        return f" {text:<{width}} "

    widths = [len(str(h)) for h in headers]
    for row in elements:
        assert len(row) == len(headers)
        for i, e in enumerate(row):
            widths[i] = max(widths[i], len(str(e)))

    lines = ["|".join(_cell(h, w) for h, w in zip(headers, widths))]
    # Each column is its width plus 2 padding spaces; columns are separated
    # by a single "|".
    total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
    lines.append("-" * total_width)
    lines.extend(
        "|".join(_cell(e, w) for e, w in zip(row, widths)) for row in elements
    )
    return "\n".join(lines)
def zip_dicts(
    dict1: Mapping[KeyType, ValType],
    dict2: Mapping[KeyType, ValType],
    d1_default: ValType | None = None,
    d2_default: ValType | None = None,
) -> Generator[tuple[KeyType, ValType | None, ValType | None], None, None]:
    """
    Zip two dictionaries together, replacing missing keys with default values.

    Keys are visited in order: all keys of ``dict1`` first, then the keys of
    ``dict2`` that ``dict1`` lacks. The defaults are used only when a key is
    absent. A key explicitly mapped to ``None`` yields ``None``.

    Args:
        dict1 (dict): The first dictionary.
        dict2 (dict): The second dictionary.
        d1_default (Any): the default value for the first dictionary
        d2_default (Any): the default value for the second dictionary

    Yields:
        tuple: A tuple containing the key, the value from dict1 (or d1_default if missing),
            and the value from dict2 (or d2_default if missing).
    """
    # Snapshot the ordered union of keys before yielding anything.
    all_keys = list(dict1) + [key for key in dict2 if key not in dict1]
    for key in all_keys:
        # get(key, default) substitutes only for *missing* keys; the previous
        # `value if value is not None else default` also replaced stored Nones.
        yield key, dict1.get(key, d1_default), dict2.get(key, d2_default)
def maybe_aoti_standalone_config(config_patches: dict[str, Any]) -> dict[str, Any]:
    """
    Ensures the configuration is internally consistent for standalone AOTInductor.

    If `aot_inductor_mode.compile_standalone` is True in the provided
    `config_patches` (or, failing that, in the global config), the following
    configs are filled in when unset (None):
      - `aot_inductor.package_cpp_only` = True
      - `aot_inductor.embed_kernel_binary` = True
      - `aot_inductor.emit_multi_arch_kernel` = True, unless running on ROCm
        (`torch.version.hip` set)
      - `aot_inductor.model_name_for_generated_files` = "aoti_model"
    An explicitly falsy value that disagrees with the required one raises.
    A truthy user value is kept as-is.

    These are always overridden, with a warning if the value changes:
      - `aot_inductor.link_libtorch` = `config.test_configs.use_libtorch`
      - `aot_inductor.dynamic_linkage` = False

    Independently of `compile_standalone`, windows cross-compilation combined
    with `aot_inductor.package_constants_in_so` is rejected.

    Args:
        config_patches (dict[str, Any]): A dictionary of user-provided config
            overrides for AOTInductor compilation.

    Returns:
        dict[str, Any]: A copy of `config_patches` with the updates applied;
            the input dictionary is not modified.

    Raises:
        RuntimeError: if a required config is explicitly set to an
            incompatible falsy value, or on windows cross-compilation with
            `package_constants_in_so`.
    """

    def patch_config(
        config_patches: dict[str, Any], config_name: str, config_value: Any
    ) -> None:
        # Fill in `config_value` only when the effective value is None.
        value = config_patches.get(config_name, getattr(config, config_name))
        if value is None:
            config_patches[config_name] = config_value
        elif not value and value != config_value:
            # Report the offending value (not the required one) so the
            # message points at what the user actually set.
            raise RuntimeError(
                f"Invalid config: {config_name}={value} when aot_inductor_mode.compile_standalone is True. "
                f"Expected {config_name}={config_value}."
            )

    def force_patch_config(
        config_patches: dict[str, Any], config_name: str, config_value: Any
    ) -> None:
        # Unconditionally set `config_value`, warning if this changes the
        # effective setting.
        value = config_patches.get(config_name, getattr(config, config_name))
        if value != config_value:
            log.warning(
                "Overriding: %s=%s when aot_inductor_mode.compile_standalone is True.",
                config_name,
                config_value,
            )
        config_patches[config_name] = config_value

    compile_standalone = config_patches.get(
        "aot_inductor_mode.compile_standalone",
        config.aot_inductor_mode.compile_standalone,
    )
    # Make a copy of the config_patches to avoid modifying the original dictionary, needed for testing
    config_patches = config_patches.copy()
    if compile_standalone:
        # Standalone AOTInductor means only generate cpp project for building a standalone binary
        patch_config(config_patches, "aot_inductor.package_cpp_only", True)
        # Standalone AOTInductor needs to embed the kernel code in the binary
        patch_config(config_patches, "aot_inductor.embed_kernel_binary", True)
        # Default to use multi-arch kernel codegen for non-rocm GPU
        patch_config(
            config_patches, "aot_inductor.emit_multi_arch_kernel", not torch.version.hip
        )
        patch_config(
            config_patches, "aot_inductor.model_name_for_generated_files", "aoti_model"
        )
        # TODO: change these two configs to default to None and use patch_config
        force_patch_config(
            config_patches,
            "aot_inductor.link_libtorch",
            config.test_configs.use_libtorch,
        )
        force_patch_config(config_patches, "aot_inductor.dynamic_linkage", False)

    cross_target_platform = config_patches.get(
        "aot_inductor.cross_target_platform",
        config.aot_inductor.cross_target_platform,
    )
    package_constants_in_so = config_patches.get(
        "aot_inductor.package_constants_in_so",
        config.aot_inductor.package_constants_in_so,
    )
    if cross_target_platform == "windows" and package_constants_in_so:
        raise RuntimeError(
            "config.aot_inductor.package_constants_in_so is not supported for windows cross-compilation. "
            "Please use config.aot_inductor.package_constants_on_disk_format = binary_blob."
        )

    return config_patches
def determine_aoti_mmap_flags(consts_size: int) -> tuple[bool, bool]:
    """
    Decide how AOTInductor stores constants.

    Returns ``(use_external_weights, use_mmap_weights)``. Decisions, in order:

    - ``force_mmap_weights`` together with the "binary_blob" on-disk format
      is rejected with RuntimeError.
    - ``force_mmap_weights`` gives ``(False, True)``, unless the cross target
      is windows, which raises RuntimeError.
    - The "binary_blob" on-disk format gives ``(True, False)``.
    - ``consts_size <= 2_000_000_000`` (presumably bytes; confirm with the
      caller) gives ``(False, False)``.
    - Anything larger gives ``(False, True)`` outside fbcode and
      ``(False, False)`` inside fbcode.
    """
    aoti_cfg = config.aot_inductor
    force_mmap = aoti_cfg.force_mmap_weights
    wants_binary_blob = aoti_cfg.package_constants_on_disk_format == "binary_blob"

    if force_mmap and wants_binary_blob:
        raise RuntimeError(
            "config.aot_inductor.package_constants_on_disk_format = binary_blob and "
            "config.aot_inductor.force_mmap_weights cannot both be True."
        )

    if force_mmap:
        if aoti_cfg.cross_target_platform == "windows":
            raise RuntimeError(
                "when cross_target_platform is windows, use_mmap_weights should not be true."
            )
        return False, True

    if wants_binary_blob:
        return True, False

    if consts_size <= 2_000_000_000:
        return False, False

    return False, not config.is_fbcode()
- def is_valid_aoti_model_name() -> bool:
- """
- Validates if a model name is suitable for use in code generation.
- """
- from torch._inductor import config
- model_name = config.aot_inductor.model_name_for_generated_files
- if model_name is None:
- return True
- if not isinstance(model_name, str):
- raise ValueError("Invalid AOTI model name: Model name must be a string")
- if model_name == "":
- return True
- # Can only contain alphanumeric characters and underscores
- if not re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", model_name):
- raise ValueError(
- "Invalid AOTI model name: Model name can only contain letters, numbers, and underscores"
- )
- return True
def get_free_symbols(x: IterateExprs, unbacked_only: bool) -> OrderedSet[sympy.Symbol]:
    """Return the free symbols of ``x``.

    When ``unbacked_only`` is set, only the unbacked ones are returned.
    """
    collect = free_unbacked_symbols if unbacked_only else free_symbols
    return collect(x)
def python_subprocess_env() -> dict[str, str]:
    """
    Build the environment for launching a Python subprocess.

    - Starts from a copy of the current process environment.
    - Sets ``PYTHONPATH`` to ``$TORCH_CUSTOM_PYTHONPATH`` if defined,
      otherwise to this process's ``sys.path`` joined with ``os.pathsep``,
      so the child can import torch.
    - In fbcode builds, also sets ``PYTHONHOME`` to ``sysconfig``'s "data"
      path.
    """
    env = dict(os.environ)
    env["PYTHONPATH"] = os.environ.get(
        "TORCH_CUSTOM_PYTHONPATH", os.pathsep.join(sys.path)
    )

    # Internal builds may bundle their own Python runtime; PYTHONHOME makes
    # the child use it rather than the platform runtime's libraries/headers.
    # External builds must not do this: they may run from a venv that lacks
    # Python headers, so the child has to locate the platform runtime itself.
    if config.is_fbcode():
        env["PYTHONHOME"] = sysconfig.get_path("data")

    return env
@dataclasses.dataclass(frozen=True)
class CUDAGraphWrapperMetadata:
    """
    Metadata for Customized CUDAGraphWrapper.

    Passed as the second argument to a user-supplied partition wrapper
    (see ``CUDAGraphWrapperType``). Instances are immutable (``frozen=True``).

    Currently assumes there is 1 dynamo graph and will extend to
    multiple graphs in the future.
    """

    # The number of partitions that are cudagraphable.
    num_partitions: int
    # Index of the current partition.
    partition_index: int


# A graph-partition function; its signature is intentionally unconstrained.
PartitionFnType = Callable[..., Any]
# Signature of a user-supplied partition wrapper: it takes a partition
# function plus that partition's CUDAGraphWrapperMetadata and returns the
# function to use in its place.
CUDAGraphWrapperType = Callable[
    [PartitionFnType, CUDAGraphWrapperMetadata], PartitionFnType
]
# Holder for the optional user-supplied partition wrapper. The module-level
# instance below is set by set_customized_partition_wrappers and read by
# is_using_cudagraph_partition.
class CUDAGraphWrapper:
    # None means no customized wrapper has been registered.
    wrapper: Optional[CUDAGraphWrapperType] = None


# A customized partition wrapper from users. Its interface should be:
#
# def wrapper(fn: PartitionFnType, metadata: CUDAGraphWrapperMetadata) -> PartitionFnType
#
# Inductor generates N wrapper functions for N partition functions, and mechanically wraps
# each partition fn with the generated wrapper function. Users need to handle all details
# such as static inputs, dynamic shapes, etc.
# Users could customize the wrapper based on the metadata. One example is special
# handling for the first and last wrapper function.
#
# Warning: This API is unstable and may change in the future.
_unstable_customized_partition_wrapper = CUDAGraphWrapper()
def set_customized_partition_wrappers(wrapper: CUDAGraphWrapperType) -> None:
    """Register ``wrapper`` as the user partition wrapper.

    The wrapper is stored on the module-level holder and replaces any wrapper
    registered earlier. This API is unstable.
    """
    holder = _unstable_customized_partition_wrapper
    holder.wrapper = wrapper
def snode_args_kwargs(snode: BaseSchedulerNode) -> tuple[list[Any], dict[str, Any]]:
    """
    Rebuild ``(args, kwargs)`` for the kernel behind ``snode`` with concrete tensors.

    Steps:
    1. Take the node's inputs followed by its constant args, and let
       ``fill_non_provided_args`` complete them against the node's kwargs.
    2. Flatten ``(args, kwargs)`` with pytree.
    3. Convert IR nodes (except GeneratorState) via
       ``ir.ir_node_to_tensor(..., guard_shape=False)``.
    4. Replace every tensor leaf with a fresh ``torch.empty`` tensor of the
       same size, dtype and device. Contents are uninitialized.
    """
    node = snode.node
    filled_args = node.fill_non_provided_args(  # type: ignore[union-attr]
        [*node.inputs, *node.constant_args],  # type: ignore[union-attr]
        node.kwargs,  # type: ignore[union-attr]
    )
    node_kwargs = node.kwargs  # type: ignore[union-attr]
    leaves, spec = pytree.tree_flatten((filled_args, node_kwargs))

    def _is_tensor_ir(x) -> bool:  # type: ignore[no-untyped-def]
        return isinstance(x, torch._inductor.ir.IRNode) and not isinstance(
            x, torch._inductor.ir.GeneratorState
        )

    # Pass 1: IR nodes -> tensors (no shape guards are added).
    leaves = [
        torch._inductor.ir.ir_node_to_tensor(leaf, guard_shape=False)
        if _is_tensor_ir(leaf)
        else leaf
        for leaf in leaves
    ]
    # Pass 2: every tensor -> newly allocated, uninitialized tensor with the
    # same metadata; non-tensor leaves are passed through untouched.
    leaves = [
        torch.empty(leaf.size(), dtype=leaf.dtype, device=leaf.device)
        if isinstance(leaf, torch.Tensor)
        else leaf
        for leaf in leaves
    ]

    args, kwargs = pytree.tree_unflatten(leaves, spec)
    return args, kwargs
def is_nonfreeable_buffers(dep: Dep) -> bool:
    """
    Return True if ``dep`` names a buffer that must not be freed.

    If the current graph has a name, a leading ``"<graph name>_"`` (added
    for subgraphs) is stripped first. The remaining name is then checked
    against the prefixes primals_, arg, fwd_rng_state, bwd_rng_state and
    tangents.
    """
    from .virtualized import V

    graph_name = V.graph.name
    buffer_name = dep.name
    if graph_name:
        buffer_name = buffer_name.removeprefix(graph_name + "_")

    protected_prefixes = (
        "primals_",
        "arg",
        "fwd_rng_state",
        "bwd_rng_state",
        "tangents",
    )
    return buffer_name.startswith(protected_prefixes)
# Make sure to also include your jinja templates within torch_package_data in setup.py, or this function won't be able to find them
def load_template(name: str, template_dir: Path) -> str:
    """
    Load ``<template_dir>/<name>.py.jinja`` and return its contents.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's locale encoding. ``template_dir`` may be a ``Path`` or any
    other path-like / string path.
    """
    template_path = os.path.join(template_dir, f"{name}.py.jinja")
    with open(template_path, encoding="utf-8") as f:
        return f.read()
def should_fallback_by_default(node: torch.fx.Node) -> bool:
    """
    Decide whether to fall back for ``node``. Only used in inductor lite mode.

    Asserts the target is an OpOverload or HigherOrderOperator. Decisions, in
    order:
    - ``config.fallback_by_default`` off: no fallback.
    - ``_assert_scalar`` / ``lift_fresh_copy``: never fall back.
    - Higher-order ops: fall back only for hops without a registered
      lowering (currently triton_kernel_wrapper_functional).
    - Any other op: fall back unless it needs inductor compilation.
    """
    target = node.target
    assert isinstance(
        target, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
    ), f"Expected OpOverload or HigherOrderOperator, but found {type(target)}"

    if not config.fallback_by_default:
        return False

    # These ops need special handling under dynamic shapes; compiling them
    # instead of falling back does not affect numerics.
    dynamic_shape_exempt = OrderedSet(
        [
            torch.ops.aten._assert_scalar.default,
            torch.ops.aten.lift_fresh_copy.default,
        ]
    )
    if target in dynamic_shape_exempt:
        return False

    # Most hops register a lowering, which should be followed. The rare ones
    # without a lowering are listed here and do fall back.
    hops_without_lowering = OrderedSet(
        [torch.ops.higher_order.triton_kernel_wrapper_functional]
    )
    if not isinstance(target, torch._ops.HigherOrderOperator):
        return not _needs_inductor_compile(node)
    return target in hops_without_lowering
# Collective operation names for specialized benchmarking.
# Fully qualified op-overload names from both the _c10d_functional and the
# _c10d_functional_autograd namespaces; is_collective_op tests string
# membership against this set.
COLLECTIVE_OPS = OrderedSet(
    [
        "torch.ops._c10d_functional.all_reduce.default",
        "torch.ops._c10d_functional.all_reduce_.default",
        "torch.ops._c10d_functional.all_gather_into_tensor.default",
        "torch.ops._c10d_functional.reduce_scatter_tensor.default",
        "torch.ops._c10d_functional.all_to_all_single.default",
        "torch.ops._c10d_functional_autograd.all_reduce.default",
        "torch.ops._c10d_functional_autograd.all_gather_into_tensor.default",
        "torch.ops._c10d_functional_autograd.reduce_scatter_tensor.default",
        "torch.ops._c10d_functional_autograd.all_to_all_single.default",
    ]
)
def is_collective_op(op_name: str) -> bool:
    """Return whether ``op_name`` is one of the names listed in COLLECTIVE_OPS."""
    known_collective = op_name in COLLECTIVE_OPS
    return known_collective
@lru_cache
def tlx_only_cuda_options() -> list[str]:
    """
    Return the TLX-only CUDA options exported by the fbcode registry.

    Outside fbcode, or when the registry module cannot be imported, this is
    an empty list. The result is cached, so every caller receives the same
    list object.
    """
    if not config.is_fbcode():
        return []
    try:
        from torch._inductor.fb.tlx_templates.registry import (
            tlx_only_cuda_options as registry_options,
        )
    except ImportError:
        return []
    return registry_options
def _round_up(x: int, y: int) -> int:
    """Round ``x`` up to a multiple of ``y`` (the next multiple >= x when y > 0).

    Only ``+``, ``-``, ``//`` and ``*`` are used, so symbolic (sympy) sizes
    work too.
    """
    num_blocks = (x + y - 1) // y
    return num_blocks * y
def _infer_scale_swizzle_impl(
    mat_size: tuple[Any, Any],
    scale_size: tuple[Any, ...],
    scale_numel: Any,
    mat_dtype: torch.dtype,
    scale_dtype: torch.dtype,
    eq_fn: Callable[[Any, Any], bool],
) -> tuple[Optional[Any], Optional[Any]]:
    """
    Core implementation for scale/swizzle inference.

    Shared by ``infer_scale_swizzle`` (concrete tensors) and
    ``infer_scale_swizzle_ir`` (IR buffers). The two differ only in the
    ``eq_fn`` used to compare sizes.

    Args:
        mat_size: ``(size0, size1)`` of the matrix as stored.
        scale_size: Shape of the scale tensor.
        scale_numel: Total number of elements in the scale tensor.
        mat_dtype: Matrix dtype. ``torch.float4_e2m1fn_x2`` counts as packed
            (2 values per element) when computing expected block-scale counts.
        scale_dtype: Scale dtype. Selects the NVFP4 (float8_e4m3fn) and
            MXFP8 (float8_e8m0fnu) checks.
        eq_fn: Equality predicate used for every size comparison.

    Returns:
        ``(ScalingType, SwizzleType)`` for the first layout that matches. The
        order is TensorWise, RowWise, BlockWise1x128, BlockWise128x128,
        BlockWise1x16, BlockWise1x32. Returns ``(None, None)`` when nothing
        matches.
    """
    from torch.nn.functional import ScalingType, SwizzleType

    # Tensor-wise: single scale for entire tensor
    if eq_fn(scale_numel, 1):
        return ScalingType.TensorWise, SwizzleType.NO_SWIZZLE

    # The shape-based checks below need a 2-D scale; otherwise only the
    # numel-based NVFP4/MXFP8 checks further down can match.
    if len(scale_size) >= 2:
        # Row-wise: one scale per row (size0, 1) or one per column (1, size1)
        if (eq_fn(scale_size[0], mat_size[0]) and eq_fn(scale_size[1], 1)) or (
            eq_fn(scale_size[0], 1) and eq_fn(scale_size[1], mat_size[1])
        ):
            return ScalingType.RowWise, SwizzleType.NO_SWIZZLE

        # Block-wise 1x128 / 128x1 (DeepGemm style): one dim matches exactly,
        # the other is divided into 128-wide blocks (rounded up).
        if (
            eq_fn(scale_size[0], mat_size[0])
            and eq_fn(scale_size[1], ceildiv(mat_size[1], 128))
        ) or (
            eq_fn(scale_size[1], mat_size[1])
            and eq_fn(scale_size[0], ceildiv(mat_size[0], 128))
        ):
            return ScalingType.BlockWise1x128, SwizzleType.NO_SWIZZLE

        # Block-wise 128x128: both dims divided into 128-wide blocks.
        if eq_fn(scale_size[0], ceildiv(mat_size[0], 128)) and eq_fn(
            scale_size[1], ceildiv(mat_size[1], 128)
        ):
            return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE

    # Adjust for packed FP4 data (2 values per element)
    K_multiplier = 2 if mat_dtype == torch.float4_e2m1fn_x2 else 1

    # NVFP4: BlockWise1x16 with float8_e4m3fn scales.
    # Expected numel = (other dim rounded up to 128) * (ceil(K / 16) rounded
    # up to 4), tried with K taken from either matrix dim (a / b variants).
    if mat_dtype == torch.float4_e2m1fn_x2 and scale_dtype == torch.float8_e4m3fn:
        expected_numel_a = _round_up(mat_size[0], 128) * _round_up(
            ceildiv(K_multiplier * mat_size[1], 16), 4
        )
        expected_numel_b = _round_up(mat_size[1], 128) * _round_up(
            ceildiv(K_multiplier * mat_size[0], 16), 4
        )
        if eq_fn(scale_numel, expected_numel_a) or eq_fn(scale_numel, expected_numel_b):
            return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4

    # MXFP8: BlockWise1x32 with float8_e8m0fnu scales
    if scale_dtype == torch.float8_e8m0fnu:
        if not torch.version.hip:
            # NVIDIA: uses swizzled 32x4x4 layout. Same padding scheme as
            # NVFP4 above, but with 32-wide K blocks.
            expected_numel_a = _round_up(mat_size[0], 128) * _round_up(
                ceildiv(K_multiplier * mat_size[1], 32), 4
            )
            expected_numel_b = _round_up(mat_size[1], 128) * _round_up(
                ceildiv(K_multiplier * mat_size[0], 32), 4
            )
            if eq_fn(scale_numel, expected_numel_a) or eq_fn(
                scale_numel, expected_numel_b
            ):
                return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4
        else:
            # AMD: no swizzle and no padding; plain ceil(K / 32) * other_dim.
            # NOTE(review): unlike the NVIDIA formulas, expected_numel_a applies
            # K_multiplier outside ceildiv and to mat_size[1] rather than
            # mat_size[0]. This only matters for packed FP4 inputs; confirm it
            # is intended.
            expected_numel_a = ceildiv(mat_size[0], 32) * K_multiplier * mat_size[1]
            expected_numel_b = ceildiv(K_multiplier * mat_size[1], 32) * mat_size[0]
            if eq_fn(scale_numel, expected_numel_a) or eq_fn(
                scale_numel, expected_numel_b
            ):
                return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE

    return None, None
def infer_scale_swizzle(
    mat: torch.Tensor, scale: torch.Tensor
) -> tuple[Optional[Any], Optional[Any]]:
    """
    Infer ``(ScalingType, SwizzleType)`` from concrete matrix/scale tensors.

    Recognized layouts:
    - TensorWise: a single scale for the whole tensor.
    - RowWise: one scale per row (or one per column).
    - BlockWise1x128 / BlockWise128x128: block-scaled with float32 scales.
    - BlockWise1x32: MXFP8 with float8_e8m0fnu scales (swizzled on NVIDIA).
    - BlockWise1x16: NVFP4 with float8_e4m3fn scales (swizzled).

    Sizes are compared with plain ``==``.

    Args:
        mat: The matrix tensor (FP8 or FP4); only its first two dims are used.
        scale: The scale factor tensor.

    Returns:
        ``(ScalingType, SwizzleType)``, or ``(None, None)`` if unrecognized.
    """

    def exact_eq(a: Any, b: Any) -> bool:
        return a == b

    rows, cols = mat.shape[0], mat.shape[1]
    return _infer_scale_swizzle_impl(
        mat_size=(rows, cols),
        scale_size=tuple(scale.shape),
        scale_numel=scale.numel(),
        mat_dtype=mat.dtype,
        scale_dtype=scale.dtype,
        eq_fn=exact_eq,
    )
def infer_scale_swizzle_ir(
    mat: Buffer,
    scale: Buffer,
    transpose: bool = False,
) -> tuple[Optional[Any], Optional[Any]]:
    """
    Infer ``(ScalingType, SwizzleType)`` for IR buffers during graph lowering.

    This is the IR counterpart of ``infer_scale_swizzle``. It differs in
    three ways:
    - Sizes are compared with ``V.graph.sizevars.statically_known_equals``,
      so symbolic sizes only match when provably equal.
    - ``transpose`` swaps the matrix's first two dims before matching.
    - A 1-D matrix is treated as ``(size0, 1)``.
    """
    from torch._inductor.virtualized import V

    mat_dims = mat.get_size()
    scale_dims = scale.get_size()

    if transpose:
        mat_dims = (mat_dims[1], mat_dims[0])

    # Product of the scale dims (1 for an empty shape); works for sympy sizes.
    scale_numel = 1
    for dim in scale_dims:
        scale_numel = scale_numel * dim

    def symbolic_eq(a: Any, b: Any) -> bool:
        # True only when the sizes are statically known to be equal.
        return V.graph.sizevars.statically_known_equals(a, b)

    if len(mat_dims) >= 2:
        matrix_shape = (mat_dims[0], mat_dims[1])
    else:
        matrix_shape = (mat_dims[0], 1)

    return _infer_scale_swizzle_impl(
        mat_size=matrix_shape,
        scale_size=tuple(scale_dims),
        scale_numel=scale_numel,
        mat_dtype=mat.dtype,
        scale_dtype=scale.dtype,
        eq_fn=symbolic_eq,
    )
|