| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241 |
- """
- Utility functions and classes used throughout the TorchDynamo system.
- This module contains a collection of helper utilities used by various parts of Dynamo for:
- - Performance metrics collection and reporting
- - Compilation timing and debugging
- - Graph manipulation and tensor operations
- - Runtime guards and checks
- - Common data structure operations
- - Testing and development tools
- This is an internal module that provides shared functionality used across the Dynamo codebase.
- """
- from __future__ import annotations
- import atexit
- import collections
- import contextlib
- import copy
- import dataclasses
- import datetime
- import dis
- import enum
- import functools
- import gc
- import importlib
- import inspect
- import itertools
- import json
- import linecache
- import logging
- import math
- import operator
- import os
- import re
- import sys
- import textwrap
- import threading
- import time
- import traceback
- import types
- import typing
- import uuid
- import warnings
- import weakref
- from collections import Counter, OrderedDict
- from contextlib import AbstractContextManager, contextmanager
- from dataclasses import is_dataclass
- from functools import lru_cache
- from types import CodeType, MethodWrapperType
- from typing import (
- Any,
- cast,
- ClassVar,
- Generic,
- Literal,
- NoReturn,
- Optional,
- overload,
- TypeAlias,
- TypeGuard,
- TypeVar,
- Union,
- )
- from typing_extensions import ParamSpec, TypeIs
- import torch
- import torch._functorch.config
- import torch.fx.experimental.symbolic_shapes
- import torch.utils._pytree as pytree
- from torch import fx
- from torch._C import (
- _instruction_counter,
- _len_torch_function_stack,
- _pop_torch_function_stack,
- _push_on_torch_function_stack,
- )
- from torch._dispatch.python import enable_python_dispatcher
- from torch._dynamo.metrics_context import MetricsContext, RuntimeMetricsContext
- from torch._guards import CompileId, Source, TracingContext
- from torch._subclasses.meta_utils import is_sparse_compressed
- from torch._utils_internal import (
- justknobs_check,
- log_chromium_event_internal,
- log_compilation_event,
- record_chromium_event_internal,
- signpost_event,
- )
- from torch.fx._utils import _format_graph_code, lazy_format_graph_code
- from torch.monitor import _WaitCounter
- from torch.nn.modules.lazy import LazyModuleMixin
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._python_dispatch import is_traceable_wrapper_subclass
- from torch.utils._triton import has_triton, has_triton_package
- from torch.utils.hooks import RemovableHandle
- from .graph_utils import _get_flat_args
- if typing.TYPE_CHECKING:
- from collections.abc import (
- Callable,
- Container,
- Generator,
- ItemsView,
- Iterable,
- Iterator,
- KeysView,
- Mapping,
- Sequence,
- ValuesView,
- )
- from torch._dynamo.bytecode_transformation import Instruction
- from torch._dynamo.replay_record import ExecutionRecord
- from torch._dynamo.symbolic_convert import (
- InstructionTranslator,
- InstructionTranslatorBase,
- )
- from torch._dynamo.variables.base import VariableTracker
- from torch._prims_common import DeviceLikeType
- from torch._subclasses import FakeTensorMode
- try:
- import numpy as np
- except ModuleNotFoundError:
- np = None # type: ignore[assignment]
- try:
- import torch._logging
- import torch._numpy as tnp
- from torch._guards import detect_fake_mode # noqa: F401
- from torch._logging import LazyString
- from . import config
- # NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
- if np:
- NP_SUPPORTED_MODULES: tuple[types.ModuleType, ...] = (
- np,
- np.fft,
- np.linalg,
- np.random,
- )
- NP_TO_TNP_MODULE = {
- np: tnp,
- np.fft: tnp.fft,
- np.linalg: tnp.linalg,
- np.random: tnp.random,
- }
- else:
- NP_SUPPORTED_MODULES = ()
- # pyrefly: ignore [implicit-any]
- NP_TO_TNP_MODULE = {}
- from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
- except ImportError:
- pass
- T = TypeVar("T")
- R = TypeVar("R")
- _P = ParamSpec("_P")
- unpatched_nn_module_getattr = torch.nn.Module.__getattr__
- unpatched_nn_module_call = torch.nn.Module.__call__
- unpatched_nn_module_call_impl = torch.nn.Module._call_impl
- counters: collections.defaultdict[str, Counter[str]] = collections.defaultdict(
- collections.Counter
- )
- optimus_scuba_log: dict[str, Any] = {}
- troubleshooting_url = "https://docs.pytorch.org/docs/main/user_guide/torch_compiler/compile/programming_model.recompilation.html"
- nnmodule_doc_url = "https://docs.pytorch.org/docs/main/user_guide/torch_compiler/torch.compiler_nn_module.html"
- nnmodule_doc_url_msg = f"See {nnmodule_doc_url} for more information and limitations."
- log = logging.getLogger(__name__)
- # profiling compilation time by function
- compilation_time_metrics: dict[str, list[float]] = {}
- # This supports calculate_time_spent(), which reports cumulative times
- # across the process for any "phase" populated by dynamo_timed. Reset if
- # reset_frame_count() is called.
- cumulative_time_spent_ns: dict[str, float] = collections.defaultdict(float)
- timer_counter = itertools.count()
class ReInplaceTrigger(enum.Enum):
    """Origin of a re-inplacing decision.

    ReinplaceCounters uses each member's ``name`` to build its counter keys.
    """

    # enum.auto() numbers members 1, 2, 3 in declaration order.
    AUTO_FUNC_V1 = enum.auto()
    AUTO_FUNC_V2 = enum.auto()
    TRITON_OPS = enum.auto()
class ReinplaceCounters:
    """Class-level tallies of tensors that were not re-inplaced.

    Counts live in the shared ``_values`` defaultdict under the keys
    ``"missed_tensors_<TRIGGER>"`` and ``"missed_bytes_<TRIGGER>"``, where
    ``<TRIGGER>`` is the ``name`` of a ``ReInplaceTrigger`` member. Every
    method is a classmethod that operates on that shared mapping.
    """

    _values: collections.defaultdict[str, int] = collections.defaultdict(int)

    @classmethod
    def add_missed_bytes(cls, trigger: ReInplaceTrigger, bytes: int) -> None:
        """Add ``bytes`` to the missed-bytes tally for ``trigger``.

        Only sizes of known (non-dynamic-shape) tensors are meant to be
        tracked here. A zero amount is ignored, so no empty key is created.
        """
        # ``bytes`` shadows the builtin, but it is part of the public keyword
        # interface and is kept for backward compatibility.
        if bytes != 0:
            cls._values[f"missed_bytes_{trigger.name}"] += bytes

    @classmethod
    def add_missed_opportunities(cls, trigger: ReInplaceTrigger, count: int) -> None:
        """Add ``count`` to the number of tensors not re-inplaced for ``trigger``.

        A zero count is ignored, so no empty key is created.
        """
        if count != 0:
            cls._values[f"missed_tensors_{trigger.name}"] += count

    @classmethod
    def clear(cls) -> None:
        """Drop every recorded tally."""
        cls._values.clear()

    @classmethod
    def _total(cls, prefix: str) -> int:
        """Sum the ``<prefix>_<TRIGGER>`` tallies over all ReInplaceTrigger members."""
        # Use .get() rather than indexing: indexing a defaultdict would insert
        # zero entries, making the ``if cls._values`` check in log() truthy.
        return sum(
            cls._values.get(f"{prefix}_{trigger.name}", 0)
            for trigger in ReInplaceTrigger
        )

    @classmethod
    def get_total_missed(cls) -> int:
        """Return the number of missed re-inplace tensors across all triggers."""
        return cls._total("missed_tensors")

    @classmethod
    def get_total_missed_bytes(cls) -> int:
        """Return the number of missed re-inplace bytes across all triggers."""
        return cls._total("missed_bytes")

    @classmethod
    def log(cls) -> None:
        """Emit the tallies via signpost_event, but only if any were recorded."""
        if cls._values:
            signpost_event("inductor", "reinplace_counters", cls._values)
def tabulate(
    rows: Union[list[tuple[str, Any]], list[list[Any]]],
    headers: Union[tuple[str, ...], list[str]],
) -> str:
    """Render ``rows`` under ``headers`` as a plain-text table.

    Uses the optional third-party ``tabulate`` package when it can be
    imported. Otherwise it falls back to one comma-separated line per row,
    with the headers as the first line. Any ImportError raised while
    rendering also triggers the fallback (best effort).
    """
    try:
        import tabulate as _tabulate_pkg

        return _tabulate_pkg.tabulate(rows, headers=headers)
    except ImportError:
        lines = []
        for row in [headers, *rows]:
            lines.append(", ".join(str(cell) for cell in row))
        return "\n".join(lines)
curr_frame = 0


# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def increment_frame() -> None:
    """Advance the module-global ``curr_frame`` counter by one."""
    global curr_frame
    curr_frame += 1
# Note: Called for you by dynamo - you almost never ever want to invoke this yourself.
def reset_frame_count() -> None:
    """Reset per-process frame and compile-time bookkeeping.

    Empties ``cumulative_time_spent_ns``, then ``compilation_time_metrics``,
    and finally sets ``curr_frame`` back to zero.
    """
    global curr_frame
    for metrics in (cumulative_time_spent_ns, compilation_time_metrics):
        metrics.clear()
    curr_frame = 0
_recompile_user_contexts: Optional[list[Callable[[], str]]] = None


def register_hook_for_recompile_user_context(hook: Callable[[], str]) -> None:
    """
    Register a hook to be called when a recompile is triggered. The hook
    should return a string describing user contexts that are not available
    to the compiler, such as the current training epoch. This is useful for
    debugging and data analysis for recompile. For data retention purposes,
    the user context string is capped at 256 characters (the cap is not
    applied here).

    Hooks accumulate in registration order. The backing list is created
    lazily on the first registration.
    """
    global _recompile_user_contexts
    hooks = _recompile_user_contexts
    if hooks is None:
        # Bind the same new list to both the local and the module global.
        hooks = _recompile_user_contexts = []
    hooks.append(hook)
def get_hook_for_recompile_user_context() -> Optional[list[Callable[[], str]]]:
    """Return the list of registered recompile user-context hooks.

    Returns ``None`` if no hook has been registered since module import or
    since the last reset. The live list is returned, not a copy.
    """
    hooks = _recompile_user_contexts
    return hooks
def reset_recompile_user_contexts() -> None:
    """Forget every registered recompile user-context hook (test helper)."""
    global _recompile_user_contexts
    # Rebind to None rather than emptying the list, which restores the
    # "nothing registered" state instead of leaving an empty list behind.
    _recompile_user_contexts = None
op_count = 0


def increment_op_count(cnt: int) -> None:
    """Add ``cnt`` to the module-global ``op_count`` tally."""
    global op_count
    op_count = op_count + cnt
# Get the total time in seconds for each "phase"
# For example, {'entire_frame_compile':8.574629999999999, 'backend_compile':5.26806}
def calculate_time_spent() -> dict[str, float]:
    """Return cumulative seconds spent per phase recorded in ``cumulative_time_spent_ns``.

    Each nanosecond entry is converted to seconds. A synthetic
    ``"total_wall_time"`` key is added, equal to ``entire_frame_compile +
    entire_backward_compile``, where a missing phase counts as 0.
    """
    # Annotating the accumulator explicitly makes the type checker accept it
    # without the suppressions an untyped ``{}`` literal needed.
    total_by_key: dict[str, float] = {
        phase: timing / 1e9 for phase, timing in cumulative_time_spent_ns.items()
    }
    total_by_key["total_wall_time"] = total_by_key.get(
        "entire_frame_compile", 0
    ) + total_by_key.get("entire_backward_compile", 0)
    return total_by_key
# Print a one-line report of the time spent so far, for example:
#   TIMING: entire_frame_compile:8.57463 backend_compile:5.26806 total_wall_time:8.57463
def print_time_report() -> None:
    """Print ``calculate_time_spent()`` as a single ``TIMING:`` line.

    Each phase is rendered as ``<phase>:<seconds>``, with seconds rounded to
    five decimal places, and entries are separated by single spaces.
    """
    fields = [
        f"{key}:{round(value, 5)}" for key, value in calculate_time_spent().items()
    ]
    print(" ".join(["TIMING:", *fields]))
# Use the following per-thread singleton to capture and log CompilationMetrics. Entering
# the context manager allocates a new record to be logged when it exits. (You should not
# need to use this directly unless you introduce a new code path where compilation metrics
# would be gathered.) While compiling, use the setters or timer in MetricsContext to update
# fields in the current context. For example:
#
# To set a single field once (use overwrite=True to overwrite):
#   get_metrics_context().set("metric_name", value)
#
# To set multiple fields at once (use overwrite=True to overwrite):
#   get_metrics_context().update({"name1": val1, "name2": val2})
#
# To increment an integer field:
#   get_metrics_context().increment("metric_name", value)
#
# To record execution time, MetricsContext works with dynamo_timed:
#   def foo(...):
#       # Updates the "metric_us" field.
#       with dynamo_timed("metric", dynamo_compile_column_us="metric_us"):
#           ...
#
_metrics_context_tls = threading.local()


def get_metrics_context() -> MetricsContext:
    """Return this thread's MetricsContext, creating it on first use.

    The instance is cached on the thread-local ``_metrics_context_tls`` and is
    built with ``record_compilation_metrics`` as its ``on_exit`` callback.
    """
    try:
        return _metrics_context_tls.metrics_context
    except AttributeError:
        # First call on this thread: nothing cached yet.
        ctx = MetricsContext(on_exit=record_compilation_metrics)
        _metrics_context_tls.metrics_context = ctx
        return ctx
def get_runtime_metrics_context() -> RuntimeMetricsContext:
    """Return this thread's RuntimeMetricsContext, creating it on first use.

    The instance is cached on the thread-local ``_metrics_context_tls`` and is
    built with ``record_compilation_metrics`` as its ``on_exit`` callback.
    """
    try:
        return _metrics_context_tls.runtime_metrics_context
    except AttributeError:
        # First call on this thread: nothing cached yet.
        ctx = RuntimeMetricsContext(on_exit=record_compilation_metrics)
        _metrics_context_tls.runtime_metrics_context = ctx
        return ctx
class CompileEventLogLevel(enum.Enum):
    """
    Enum that loosely corresponds with a "log level" of a given event.
    Each level logs to every destination of the levels above it in this
    list, plus its own:

    CHROMIUM: Logs only to tlparse.
    PT2_COMPILE: Logs to tlparse + PT2 Compile Events.
    COMPILATION_METRIC: Logs to tlparse, PT2 Compile Events, and dynamo_compile.
    """

    CHROMIUM = 1
    PT2_COMPILE = 2
    COMPILATION_METRIC = 3
class CompileEventLogger:
    """
    Helper class for adding metadata (i.e. columns) to various compile events.

    Use CompileEventLogger to add event data to:
    - Chromium events
    - PT2 Compile Events
    - CompilationMetrics

    This should be used in conjunction with dynamo_timed() and metrics contexts, which create
    timed spans and events. CompileEventLogger uses three log levels (described in
    CompileEventLogLevel), where each log level logs to all sources below it in the hierarchy.

    Example usages:
    - I want to log to an existing chromium event within dynamo timed:
        with dynamo_timed("my_event"):
            CompileEventLogger.chromium("my_event", foo=bar)

    - I want to log my event to both chromium + pt2_compile_events:
        with dynamo_timed("my_event", log_pt2_compile_event=True):
            CompileEventLogger.pt2_compile("my_event", foo=bar)

    - I want to add information to dynamo events and dynamo_compile
        CompileEventLogger.compilation_metric(foo=bar)
    """

    @staticmethod
    def log_instant_event(
        event_name: str,
        metadata: dict[str, Any],
        time_ns: Optional[int] = None,
        log_level: CompileEventLogLevel = CompileEventLogLevel.CHROMIUM,
    ) -> None:
        """
        Log a zero-duration event <event_name> at <time_ns> (defaults to now).

        Only CHROMIUM and PT2_COMPILE are accepted; PT2_COMPILE passes
        log_pt2_compile_event=True through to the chromium event logger.

        Raises:
            RuntimeError: if log_level is COMPILATION_METRIC.
        """
        if time_ns is None:
            time_ns = time.time_ns()
        chromium_log = get_chromium_event_logger()
        if log_level == CompileEventLogLevel.CHROMIUM:
            log_pt2_compile_event = False
        elif log_level == CompileEventLogLevel.PT2_COMPILE:
            log_pt2_compile_event = True
        else:
            raise RuntimeError(
                "Cannot log instant event at COMPILATION_METRIC level. "
                "Please choose one of CHROMIUM or PT2_COMPILE"
            )
        chromium_log.log_instant_event(
            event_name, time_ns, metadata, log_pt2_compile_event
        )

    @staticmethod
    def add_data(
        event_name: str,
        log_level: CompileEventLogLevel,
        overwrite: bool = False,
        **metadata: object,
    ) -> None:
        """
        Centralized API for adding data to various events.

        - CHROMIUM: attach <metadata> to the chromium event <event_name>.
        - PT2_COMPILE: same, but <event_name> must currently be on the
          pt2_compile substack (i.e. started with log_pt2_compile_event=True).
        - COMPILATION_METRIC: <event_name> must be the outermost (toplevel) event;
          <metadata> is merged into the in-progress metrics context (forwarding
          <overwrite>) and also attached to the chromium event.

        Raises:
            RuntimeError: if <event_name> is not eligible for the requested level,
                or (COMPILATION_METRIC) no metrics context is in progress.
        """
        chromium_log = get_chromium_event_logger()
        if log_level == CompileEventLogLevel.CHROMIUM:
            chromium_log.add_event_data(event_name, **metadata)
        elif log_level == CompileEventLogLevel.PT2_COMPILE:
            pt2_compile_substack = chromium_log.get_pt2_compile_substack()
            if event_name not in pt2_compile_substack:
                # RuntimeError does not %-interpolate extra args; format the
                # message eagerly so the event name actually appears in it.
                raise RuntimeError(
                    f"Error: specified log level PT2_COMPILE, but the event {event_name}"
                    " is not logged to pt2_compile_events. Make sure the event is active and you passed "
                    "log_pt2_compile_event=True to dynamo_timed"
                )
            chromium_log.add_event_data(event_name, **metadata)
        else:
            assert log_level == CompileEventLogLevel.COMPILATION_METRIC
            top_event = chromium_log.get_outermost_event()

            if event_name != top_event:
                raise RuntimeError(
                    "Log level is COMPILATION_METRIC, but event_name isn't the toplevel event. "
                    "CompilationMetrics must be logged to the toplevel event. Consider using `add_toplevel` directly."
                )
            metrics_context = get_metrics_context()
            if not metrics_context.in_progress():
                raise RuntimeError(
                    "No metrics context is in progress. Please only call this function within a metrics context."
                )

            # TODO: should we assert that the keys of metadata are in CompilationMetrics?
            metrics_context.update(metadata, overwrite)
            chromium_log.add_event_data(event_name, **metadata)

    @staticmethod
    def add_toplevel(
        log_level: CompileEventLogLevel, overwrite: bool = False, **metadata: object
    ) -> None:
        """
        Syntactic sugar for add_data() on the current outermost (toplevel) event.

        Raises:
            RuntimeError: if no toplevel event is active.
        """
        top_event = get_chromium_event_logger().get_outermost_event()
        if top_event is None:
            raise RuntimeError(
                "No toplevel event active. Please only call this function within a dynamo_timed context."
            )
        CompileEventLogger.add_data(top_event, log_level, overwrite, **metadata)

    @staticmethod
    def increment(
        event_name: str, log_level: CompileEventLogLevel, key: str, value: int
    ) -> None:
        """
        Increment field <key> of <event_name> by <value>, adding it if missing.

        At COMPILATION_METRIC level <event_name> must be the toplevel event and
        the in-progress metrics context is incremented as well.

        Raises:
            RuntimeError: if the COMPILATION_METRIC preconditions do not hold.
        """
        chromium_log = get_chromium_event_logger()
        if (
            log_level == CompileEventLogLevel.CHROMIUM
            or log_level == CompileEventLogLevel.PT2_COMPILE
        ):
            chromium_log.increment(event_name, key, value)
        else:
            assert log_level == CompileEventLogLevel.COMPILATION_METRIC
            top_event = chromium_log.get_outermost_event()
            if event_name != top_event:
                raise RuntimeError(
                    "Log level is COMPILATION_METRIC, but event_name isn't the toplevel event. "
                    "CompilationMetrics must be logged to the toplevel event. Consider using `increment_toplevel` directly."
                )

            metrics_context = get_metrics_context()
            if not metrics_context.in_progress():
                raise RuntimeError(
                    "No metrics context is in progress. Please only call this function within a metrics context/dynamo_timed."
                )

            metrics_context.increment(key, value)
            chromium_log.increment(event_name, key, value)

    @staticmethod
    def increment_toplevel(
        key: str,
        value: int = 1,
        log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC,
    ) -> None:
        """
        Increment <key> on the toplevel event. Defaults to COMPILATION_METRIC level.

        Raises:
            RuntimeError: if no toplevel event is active.
        """
        chromium_log = get_chromium_event_logger()
        top_event = chromium_log.get_outermost_event()
        if top_event is None:
            raise RuntimeError(
                "No toplevel event active. Please only call this function within a metrics context/dynamo_timed."
            )
        CompileEventLogger.increment(top_event, log_level, key, value)

    @staticmethod
    def add_to_set(
        event_name: str, log_level: CompileEventLogLevel, key: str, value: Any
    ) -> None:
        """
        Add <value> to the set stored under <key> on <event_name>, creating the
        set if it doesn't exist.

        At COMPILATION_METRIC level <event_name> must be the toplevel event and
        the in-progress metrics context is updated as well.

        Raises:
            RuntimeError: if the COMPILATION_METRIC preconditions do not hold.
        """
        chromium_log = get_chromium_event_logger()
        if (
            log_level == CompileEventLogLevel.CHROMIUM
            or log_level == CompileEventLogLevel.PT2_COMPILE
        ):
            chromium_log.add_to_set(event_name, key, value)
        else:
            assert log_level == CompileEventLogLevel.COMPILATION_METRIC
            top_event = chromium_log.get_outermost_event()
            if event_name != top_event:
                raise RuntimeError(
                    "Log level is COMPILATION_METRIC, but event_name isn't the toplevel event. "
                    "CompilationMetrics must be logged to the toplevel event. Consider using `add_to_set_toplevel` directly."
                )

            metrics_context = get_metrics_context()
            if not metrics_context.in_progress():
                raise RuntimeError(
                    "No metrics context is in progress. Please only call this function within a metrics context/dynamo_timed."
                )

            metrics_context.add_to_set(key, value)
            chromium_log.add_to_set(event_name, key, value)

    @staticmethod
    def add_to_set_toplevel(
        key: str,
        value: Any,
        log_level: CompileEventLogLevel = CompileEventLogLevel.COMPILATION_METRIC,
    ) -> None:
        """
        Same as add_to_set(), but targets the toplevel event instead of having to
        name it explicitly. Defaults to COMPILATION_METRIC log level.

        Raises:
            RuntimeError: if no toplevel event is active.
        """
        chromium_log = get_chromium_event_logger()
        top_event = chromium_log.get_outermost_event()
        if top_event is None:
            raise RuntimeError(
                "No toplevel event active. Please only call this function within a metrics context/dynamo_timed."
            )
        CompileEventLogger.add_to_set(top_event, log_level, key, value)

    # Helper functions that are syntactic sugar

    @staticmethod
    def chromium(event_name: str, **metadata: object) -> None:
        """
        Add <metadata> to <event_name> in chromium. Each key/value of metadata will appear in the chromium trace.

        <event_name> should be the name of a timed event span passed to `dynamo_timed`.
        """
        CompileEventLogger.add_data(
            event_name, CompileEventLogLevel.CHROMIUM, overwrite=False, **metadata
        )

    @staticmethod
    def pt2_compile(event_name: str, **metadata: object) -> None:
        """
        Add <metadata> to <event_name> in chromium and PT2 Compile Events.

        Each key/value of metadata will appear in the chromium trace. Each kwarg name becomes
        a column in PT2 Compile Events, with the corresponding kwarg value.

        <event_name> should be the name of a timed event span passed to `dynamo_timed`,
        with log_pt2_compile_event=True.
        """
        CompileEventLogger.add_data(
            event_name, CompileEventLogLevel.PT2_COMPILE, overwrite=False, **metadata
        )

    @staticmethod
    def add_record_function_data(event_name: str, **metadata: object) -> None:
        """
        Add record function data to the profiler event.

        This emits profiler event data so compilation events show up in stack profilers
        like the PyTorch profiler. Nothing is emitted when the profiler is disabled or
        <metadata> is empty.

        Args:
            event_name: Name of the event to record
            **metadata: Additional metadata to attach to the record function
        """
        if torch.autograd.profiler._is_profiler_enabled and metadata:
            metadata_str = ", ".join(f"{k}={v}" for k, v in metadata.items())
            with torch.autograd.profiler.record_function(
                f"{event_name}_data: {metadata_str}"
            ):
                pass

    @staticmethod
    def compilation_metric(overwrite: bool = False, **metadata: object) -> None:
        """
        Add <metadata> to the CompilationMetrics context. Also logs to PT2 Compile Events
        and chromium.

        Each key/value of metadata will appear in the chromium trace. Each kwarg name becomes
        a column in PT2 Compile Events and Dynamo Compile, with the corresponding kwarg value.
        """
        CompileEventLogger.add_toplevel(
            CompileEventLogLevel.COMPILATION_METRIC, overwrite, **metadata
        )

    @staticmethod
    def instant(
        event_name: str, metadata: dict[str, Any], time_ns: Optional[int] = None
    ) -> None:
        """
        Log an instant event to chromium logs with name <event_name> at time <time_ns>. The `args` field in
        Perfetto will point to metadata. <time_ns> should be a value obtained from time.time_ns().
        """
        CompileEventLogger.log_instant_event(
            event_name, metadata, time_ns, CompileEventLogLevel.CHROMIUM
        )

    @staticmethod
    def try_add_pt2_compile(event_name: str, **metadata: object) -> None:
        """
        Adds to an existing pt2_compile event, but silently returns if the event doesn't exist
        or ChromiumEventLogger is not initialized.

        This function is syntactic sugar for chromium_event_logger().try_add_event_data.
        """
        if not chromium_event_log_active():
            return
        chromium_log = get_chromium_event_logger()
        chromium_log.try_add_event_data(event_name, **metadata)

    @staticmethod
    def try_(method_fn: Callable[_P, Any], *args: _P.args, **kwargs: _P.kwargs) -> None:
        """
        Quietly run method_fn(*args, **kwargs); do nothing if the chromium event log
        is not active or no metrics context is in progress.
        """
        if not chromium_event_log_active():
            return
        metrics_context = get_metrics_context()
        if not metrics_context.in_progress():
            return
        method_fn(*args, **kwargs)
# Per-thread state for dynamo_timed(): tracks the nesting depth of calls that
# update a CompilationMetrics column, so only the outermost call bumps totals.
_dynamo_timed_tls: threading.local = threading.local()
- @contextmanager
- def compile_time_record_function(name: str) -> Generator[Any, None, None]:
- """
- A context manager for compile-time profiling that uses _RecordFunctionFast
- for lower overhead than torch.profiler.record_function.
- This is intended for use during compilation (dynamo, inductor, etc.) where
- we want profiling support but with minimal overhead. Moreover, we do not
- want the record_function call inside torch.compile to be dispatched.
- Args:
- name: The name of the record function event that will appear in profiles.
- """
- if torch.autograd.profiler._is_profiler_enabled:
- rf = torch._C._profiler._RecordFunctionFast(name)
- rf.__enter__()
- try:
- yield
- finally:
- rf.__exit__(None, None, None)
- else:
- yield
@contextmanager
def dynamo_timed(
    key: str,
    # TODO(masneral): Deprecate this param.
    phase_name: Optional[str] = None,
    log_pt2_compile_event: bool = False,
    metadata: Optional[dict[str, object]] = None,
    dynamo_compile_column_us: Optional[str] = None,
    compile_id: Optional[CompileId] = None,
    is_backward: Optional[bool] = None,
    log_waitcounter: bool = False,
    waitcounter_name_override: Optional[str] = None,
) -> Generator[Any, None, None]:
    """
    dynamo_timed is a context manager.
    By wrapping a function in dynamo_timed, we can get a few things:

    1) Optionally log timings to pt2_compile_events.
    2) Optionally log timings to CompilationMetrics (dynamo_compile).
    3) Optionally log chromium events.
    4) Optionally increment a WaitCounter.
    5) Store a record in compilation_time_metrics.
       For example:

        def _foo(...):
            with dynamo_timed("_foo"):
                ...

        Would show up as an entry in our timing dict:
        OrderedDict([('_foo', [0.083690, 0.23949, 3.1425e-05])])
        This is extremely useful for granular debugging.

    Although it is tempting to use dynamo_timed as a decorator, please do not.
    In its decorator form it makes cProfile traces less useful as dynamo_timed
    suddenly becomes a bottleneck for lots of function calls (as only one parent
    pointer is recorded).

    Params:
    - key: key into compilation_time_metrics. If phase_name is not provided, this is
      also the event name used for pt2_compile_events logs and chromium events.
    - phase_name: Optional override for the event name. When given, `key` is
      recorded in the event metadata under "fn_name".
    - log_pt2_compile_event: Whether to log a pt2 compile event internally.
    - metadata: Extra metadata to put in pt2_compile_events.
    - dynamo_compile_column_us: If provided, updates the specified CompilationMetrics
      field to be logged to the dynamo_compile column. We expect all columns to be _us;
      therefore, the field name must end with "_us". The outermost such (nested)
      call additionally adds its duration to "duration_us" and enters the
      "pytorch.wait_counter.dynamo_compile" WaitCounter, plus
      "pytorch.wait_counter.compile_runtime_overheads" when no compile id is active.
    - compile_id: In the typical case, this parameter should not be needed. Use to
      supply the compile_id for those cases where we want to log a compile_id where
      it's not naturally available, e.g., for runtime autotuning.
    - is_backward: Specify forward/backward directly when not available in a
      CompileContext, e.g., during runtime autotuning.
    - log_waitcounter: If set, we'll log a waitcounter of the form
      "pytorch.wait_counter.{key}".
    - waitcounter_name_override: If set (and log_waitcounter is set), used in place
      of `key` in the waitcounter name.
    """
    # The chromium / pt2_compile event name defaults to `key`; with phase_name the
    # original key is kept as metadata instead.
    if phase_name:
        event_name = phase_name
        fn_name = key
    else:
        event_name = key
        fn_name = None

    if key not in compilation_time_metrics:
        compilation_time_metrics[key] = []

    metrics = compilation_time_metrics[key]
    event_metadata = {}
    if metadata:
        event_metadata.update(metadata)
    if fn_name:
        event_metadata.update({"fn_name": fn_name})
    if is_backward is not None:
        event_metadata.update({"is_backward": is_backward})

    chromium_log: ChromiumEventLogger = get_chromium_event_logger()
    start_ns = time.time_ns()
    chromium_log.log_event_start(
        event_name, start_ns, event_metadata, log_pt2_compile_event, compile_id
    )

    # Context managers entered for the duration of the timed region (profiler
    # range first, then any WaitCounter guards).
    cx_mgrs: list[typing.Any] = [compile_time_record_function(f"{key} (dynamo_timed)")]
    if log_waitcounter:
        wc_name = waitcounter_name_override if waitcounter_name_override else key
        cx_mgrs.append(_WaitCounter(f"pytorch.wait_counter.{wc_name}").guard())

    # No active compile id means we are outside a compilation (e.g. runtime
    # autotuning); timings then go to the runtime metrics context instead.
    is_compile_time = torch._guards.CompileContext.current_compile_id() is not None
    if dynamo_compile_column_us:
        # We're standardizing on microseconds for dynamo_compile timings.
        assert dynamo_compile_column_us.endswith("_us")

        # Track nested dynamo_timed calls that update CompilationMetrics so we can
        # bump a total duration only for the outermost metric.
        if not hasattr(_dynamo_timed_tls, "depth"):
            _dynamo_timed_tls.depth = 0
        _dynamo_timed_tls.depth += 1

        # The corresponding WaitCounters that we bump for all overheads
        if _dynamo_timed_tls.depth == 1:
            cx_mgrs.append(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard())
            if not is_compile_time:
                runtime_wc = "pytorch.wait_counter.compile_runtime_overheads"
                cx_mgrs.append(_WaitCounter(runtime_wc).guard())

    try:
        with contextlib.ExitStack() as stack:
            for cx in cx_mgrs:
                stack.enter_context(cx)
            yield
    finally:
        # Runs even if the timed body raised, so the timing record, the end event
        # and the depth counter stay consistent with the start above.
        end_ns = time.time_ns()
        time_spent_ns = end_ns - start_ns
        metrics.append(time_spent_ns / 1e9)
        chromium_log.log_event_end(
            event_name, end_ns, {}, start_ns, log_pt2_compile_event, compile_id
        )
        if dynamo_compile_column_us:
            # TODO: the events that we capture in calculate_time_spent() seem a little
            # arbitrary. Currently, it's only those fields that are present in
            # CompilationMetrics (but note that we accumulate by the associated event
            # name, not the field name in CompilationMetrics). Do we want to keep it
            # this way?
            cumulative_time_spent_ns[event_name] += time_spent_ns

            # Bump the total duration for every outer event.
            _dynamo_timed_tls.depth -= 1
            is_outer_event = _dynamo_timed_tls.depth == 0

            duration_us = time_spent_ns // 1000
            if is_compile_time:
                metrics_context = get_metrics_context()
                if metrics_context.in_progress():
                    metrics_context.increment(dynamo_compile_column_us, duration_us)
                    if is_outer_event:
                        metrics_context.increment("duration_us", duration_us)
            else:
                runtime_context = get_runtime_metrics_context()
                runtime_context.increment(dynamo_compile_column_us, duration_us)
                if is_outer_event:
                    # Note: is_backward=None is reported as is_forward=True.
                    extra = {
                        "compile_id": compile_id,
                        "is_runtime": True,
                        "is_forward": not is_backward,
                    }
                    runtime_context.increment("duration_us", duration_us, extra)
@overload
def compile_times(repr: Literal["str"], aggregate: bool = False) -> str: ...


@overload
# pyrefly: ignore [inconsistent-overload]
def compile_times(
    repr: Literal["csv"], aggregate: bool = False
) -> tuple[list[str], list[object]]: ...


def compile_times(  # type: ignore[misc]
    repr: str = "str", aggregate: bool = False
) -> Union[str, None, tuple[list[str], list[str]]]:
    """
    Report the wall-clock timings that `dynamo_timed` accumulated per key in
    compilation_time_metrics.

    repr="str" returns a printable table (seconds, 4 decimals) for interactive use;
    repr="csv" returns ``(headers, values)`` (seconds, 6 decimals) suitable for
    logging. Any other repr returns None.

    With aggregate=True the samples for a key (e.g. from several split graphs) are
    summed into one number; otherwise every sample is listed, comma-separated.
    """

    def _render(samples: list[float], fmt: Callable[[float], str] = str) -> str:
        if not aggregate:
            return ", ".join(fmt(sample) for sample in samples)
        return fmt(sum(samples))

    if repr == "csv":
        headers = list(compilation_time_metrics.keys())
        values = [
            _render(samples, fmt=lambda s: f"{s:.6f}")
            for samples in compilation_time_metrics.values()
        ]
        return headers, values

    if repr == "str":
        table = [
            (name, _render(samples, fmt=lambda s: f"{s:.4f}"))
            for name, samples in compilation_time_metrics.items()
        ]
        return "TorchDynamo compilation metrics:\n" + tabulate(
            table, headers=("Function", "Runtimes (s)")
        )

    return None
@atexit.register
def dump_compile_times() -> None:
    """
    At interpreter exit, log the aggregated compile-time table at INFO level.

    Building the table walks every recorded metric and renders it with
    `tabulate`, so the work is skipped entirely when INFO is not enabled on `log`.
    """
    if log.isEnabledFor(logging.INFO):
        log.info(compile_times(repr="str", aggregate=True))
# Legacy tensor type (e.g. torch.FloatTensor) -> the dtype object(s) naming it;
# aliases such as torch.float32 / torch.float are listed together.
tensortype_to_dtype = dict(
    [
        (torch.FloatTensor, (torch.float32, torch.float)),
        (torch.DoubleTensor, (torch.float64, torch.double)),
        (torch.HalfTensor, (torch.float16, torch.half)),
        (torch.BFloat16Tensor, (torch.bfloat16,)),
        (torch.ByteTensor, (torch.uint8,)),
        (torch.CharTensor, (torch.int8,)),
        (torch.LongTensor, (torch.int64, torch.long)),
        (torch.IntTensor, (torch.int32, torch.int)),
        (torch.ShortTensor, (torch.int16, torch.short)),
        (torch.BoolTensor, (torch.bool,)),
    ]
)
class DuplicateWarningChecker:
    """
    Bounded, LRU-ordered memory of warning keys that were already reported.

    ``add(key)`` returns True when the caller should emit the warning: always for
    an unseen key, and for a repeated key only when ``config.verbose`` is set.
    At most ``maxsize`` keys are retained; the least recently seen is evicted first.
    """

    def __init__(self, maxsize: int = 4096) -> None:
        self.maxsize = maxsize
        self.reset()

    def reset(self) -> None:
        """Forget every key seen so far."""
        self.set: OrderedDict[Any, Any] = OrderedDict()

    def add(self, key: Union[str, tuple[object, object]]) -> bool:
        """Record ``key`` and report whether its warning should be shown."""
        seen = self.set
        if key not in seen:
            seen[key] = None
        else:
            # Refresh recency; repeats are only re-reported in verbose mode.
            seen.move_to_end(key, last=True)
            if not config.verbose:
                return False
        self._trim()
        return True

    def _trim(self) -> None:
        """Evict least-recently-seen keys until at most ``maxsize`` remain."""
        seen = self.set
        while len(seen) > self.maxsize:
            seen.popitem(last=False)
# Module-level de-duplicator for graph-break warnings; cleared by
# reset_graph_break_dup_checker().
graph_break_dup_warning_checker: DuplicateWarningChecker = DuplicateWarningChecker()
def setup_compile_debug() -> contextlib.ExitStack:
    """
    Return an ExitStack for compile-debug logging.

    Only when the environment variable TORCH_COMPILE_DEBUG is exactly "1" is a
    debug-log file handler installed (via add_file_handler); otherwise the
    returned ExitStack is empty.
    """
    if os.environ.get("TORCH_COMPILE_DEBUG", "0") != "1":
        return contextlib.ExitStack()
    return add_file_handler()
def reset_graph_break_dup_checker() -> None:
    """Clear the module-level graph-break warning de-duplicator so previously
    seen keys are treated as new again."""
    checker = graph_break_dup_warning_checker
    checker.reset()
# Matches ANSI CSI escape sequences: ESC '[' then parameter bytes (0x30-0x3F),
# intermediate bytes (0x20-0x2F) and a single final byte (0x40-0x7E).
ANSI_ESCAPE_PATTERN = re.compile(
    r"\x1B"  # ESC
    r"\["  # [
    r"[0-?]*"  # parameter bytes
    r"[ -/]*"  # intermediate bytes
    r"[@-~]"  # final byte
)
class StripAnsiFormatter(logging.Formatter):
    """logging.Formatter that removes ANSI CSI escape sequences (colors etc.)
    from the fully formatted record, e.g. for plain-text log files."""

    def format(self, record: logging.LogRecord) -> str:
        return ANSI_ESCAPE_PATTERN.sub("", super().format(record))
def add_file_handler() -> contextlib.ExitStack:
    """
    Attach a FileHandler for ``<debug_dir>/torchdynamo/debug.log`` to the
    ``torch._dynamo`` logger, stripping ANSI escape codes from each record.

    Returns an ExitStack whose close() detaches the handler from the logger and
    then closes the handler's file.
    """
    log_path = os.path.join(get_debug_dir(), "torchdynamo")
    os.makedirs(log_path, exist_ok=True)

    log_file_handler = logging.FileHandler(os.path.join(log_path, "debug.log"))
    log_file_handler.setFormatter(StripAnsiFormatter("%(message)s"))

    logger = logging.getLogger("torch._dynamo")
    logger.addHandler(log_file_handler)

    exitstack = contextlib.ExitStack()
    # ExitStack unwinds LIFO: the handler is detached first, then its file is
    # closed (removing it alone would leave the file descriptor open).
    exitstack.callback(log_file_handler.close)
    exitstack.callback(logger.removeHandler, log_file_handler)
    return exitstack
def setup_log_file() -> contextlib.ExitStack:
    """
    If ``config.log_file_name`` is set, attach a single FileHandler for that file
    to every logger returned by ``torch._logging._internal.get_loggers()``.

    Returns an ExitStack whose close() detaches the handler from each of those
    loggers and then closes the file. When no log file is configured, the
    returned ExitStack is empty.
    """
    exitstack = contextlib.ExitStack()
    if config.log_file_name is not None:
        log_file_handler = logging.FileHandler(config.log_file_name)
        # Registered first so it runs last (ExitStack unwinds LIFO), after every
        # logger has dropped the handler.
        exitstack.callback(log_file_handler.close)
        for logger in torch._logging._internal.get_loggers():
            logger.addHandler(log_file_handler)
            # Bind `logger` now: a `lambda: logger.removeHandler(...)` would
            # late-bind the loop variable and only ever detach the last logger.
            exitstack.callback(logger.removeHandler, log_file_handler)
    return exitstack
def gen_record_file_name(exc: Exception, code: CodeType) -> str:
    """
    Build the execution-record path for an exception of type ``exc`` raised in
    ``code``: ``<debug_dir>/error_recordings/<co_name>_<ExcType>_<co_firstlineno>.rec``.
    """
    # Adjacent f-strings instead of a backslash-newline inside one literal: with
    # the latter, any indentation on the continuation line silently becomes part
    # of the path.
    return (
        f"{get_debug_dir()}/error_recordings/"
        f"{code.co_name}_{type(exc).__name__}_{code.co_firstlineno}.rec"
    )
def write_record_to_file(filename: str, exec_record: ExecutionRecord) -> None:
    """
    Write ``exec_record`` (via its ``dump`` method) to a new binary file at
    ``filename``, creating parent directories as needed.

    Errors are logged, not propagated: if the file already exists a warning is
    logged and nothing is written; any other failure is logged with a traceback.
    """
    try:
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        dirname = os.path.dirname(filename)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        try:
            # Exclusive-create mode fails atomically if the file exists, closing
            # the race between a separate exists() check and open().
            record_file = open(filename, "xb")
        except FileExistsError:
            log.warning(
                "Unable to write execution record %s; file already exists.", filename
            )
            return
        with record_file:
            exec_record.dump(record_file)
    except Exception:
        log.exception("Unable to write execution record %s", filename)
def count_calls(g: fx.Graph) -> int:
    """Return how many nodes of ``g`` have an opcode containing "call"
    (e.g. call_function, call_method, call_module)."""
    return sum(1 for node in g.nodes if "call" in node.op)
def identity(x: T) -> T:
    """Return ``x`` unchanged (the same object)."""
    result = x
    return result
def hashable(x: Any) -> bool:
    """Return True if ``hash(x)`` succeeds, False if it raises TypeError or ValueError."""
    try:
        hash(x)
    except (TypeError, ValueError):
        # TypeError: unhashable type; ValueError: e.g. a writable memoryview.
        return False
    return True
def nothing(*args: Any, **kwargs: Any) -> None:
    """Accept any arguments, do nothing, and return None."""
    return None
class ExactWeakKeyDictionary:
    """
    Mapping keyed by object identity (``id(key)``) rather than ``==``/``hash``,
    like weakref.WeakKeyDictionary in that keys are held only weakly.

    An entry is dropped automatically once its key is garbage collected; keys
    must therefore support weak references.
    """

    def __init__(self) -> None:
        self.values: dict[int, Any] = {}
        self.refs: dict[int, weakref.ReferenceType[Any]] = {}

    def __getitem__(self, key: Any) -> Any:
        return self.values[id(key)]

    def get(self, key: Any, default: Any = None) -> Any:
        return self.values.get(id(key), default)

    def __contains__(self, key: Any) -> bool:
        return id(key) in self.values

    def __setitem__(self, key: Any, value: Any) -> None:
        key_id = id(key)
        if key_id not in self.refs:
            # One weakref per live key; its callback purges the entry on collection.
            def _on_collect(_ref: Any, key_id: int = key_id) -> None:
                self._remove_id(key_id)

            self.refs[key_id] = weakref.ref(key, _on_collect)
        self.values[key_id] = value

    def _remove_id(self, idx: int) -> None:
        self.values.pop(idx, None)
        self.refs.pop(idx, None)

    def clear(self) -> None:
        self.refs.clear()
        self.values.clear()
@overload
def istype(obj: object, allowed_types: type[T]) -> TypeIs[T]: ...


@overload
def istype(
    obj: object, allowed_types: tuple[type[list[T]], type[tuple[T, ...]]]
) -> TypeIs[T]: ...


@overload
def istype(obj: object, allowed_types: Iterable[type]) -> bool: ...


def istype(obj: object, allowed_types: Any) -> bool:
    """
    isinstance() without subclasses: True only if ``type(obj)`` is exactly
    ``allowed_types`` or, when a tuple/list/set is given, one of its members.
    """
    obj_type = type(obj)
    if not isinstance(allowed_types, (tuple, list, set)):
        return obj_type is allowed_types
    return obj_type in allowed_types
# Typing constructs that are implemented in C on Python 3.12+ and so no longer
# carry the typing._Final mixin; is_typing() has to recognise them explicitly.
# Cross-check against CPython's typing.py, e.g.:
# https://github.com/python/cpython/blob/f2b82b3b3b1f8c7a81e84df35ee921e44517cf32/Lib/typing.py#L32
_builtin_final_typing_classes: tuple[Any, ...]
if sys.version_info >= (3, 12):
    _builtin_final_typing_classes = (
        typing.ParamSpecArgs,
        typing.ParamSpecKwargs,
        typing.ParamSpec,
        typing.TypeVar,
        typing.TypeVarTuple,
        typing.TypeAliasType,
    )
else:
    _builtin_final_typing_classes = ()
def get_inputs_devices(
    inputs: collections.abc.Sequence[object],
    model: torch.fx.GraphModule,
) -> list[Optional[torch.device]]:
    """
    Collect the distinct devices of the pytree-flattened ``inputs`` and of every
    graph node's ``meta["val"]``, in first-seen order.

    Objects without a ``device`` attribute are skipped; only ``torch.device``
    values whose type is not "meta" are kept. A trailing ``None`` is always
    appended to the result.
    """
    flat_inputs = pytree.tree_flatten(inputs)[0]
    node_vals = [node.meta["val"] for node in model.graph.nodes if "val" in node.meta]
    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique_devices = dict.fromkeys(
        obj.device for obj in flat_inputs + node_vals if hasattr(obj, "device")
    )
    kept: list[Optional[torch.device]] = [
        dev
        for dev in unique_devices
        if isinstance(dev, torch.device) and dev.type != "meta"
    ]
    kept.append(None)
    return kept
if sys.version_info >= (3, 14):
    # On 3.14+ typing.Union is no longer caught by the typing._Final check in
    # is_typing(), so register it here explicitly.
    _builtin_final_typing_classes = (*_builtin_final_typing_classes, typing.Union)
def is_typing(value: Any) -> bool:
    """
    Return True if ``value`` is a typing construct (Any, Callable, unions,
    TypeVar/ParamSpec, ...) or ``typing.Generic`` / ``typing.Union`` themselves.

    NB: we intentionally ignore classes that inherit from Generic, since they
    can be used as both TypingVariable as well as UserDefinedClassVariable.
    """
    if value is typing.Generic or value is typing.Union:
        return True
    # typing._Final catches most typing classes (Any, Callable, Union on
    # Python < 3.14, ...); types.UnionType covers PEP 604 unions like `int | str`.
    if isinstance(value, (types.UnionType, typing._Final)):  # type: ignore[attr-defined]
        return True
    # Some typing classes moved to C in 3.12 and lost the _Final mixin.
    return sys.version_info >= (3, 12) and isinstance(
        value, _builtin_final_typing_classes
    )
def is_numpy_int_type(value: Any) -> bool:
    """
    True iff ``value``'s exact type is one of NumPy's 8/16/32/64-bit signed or
    unsigned integer scalar types. Always False when ``np`` is falsy (e.g.
    NumPy is unavailable).
    """
    if not np:
        return False
    signed = (np.int8, np.int16, np.int32, np.int64)
    unsigned = (np.uint8, np.uint16, np.uint32, np.uint64)
    return istype(value, signed + unsigned)
def is_numpy_float_type(value: Any) -> bool:
    """
    True iff ``value``'s exact type is np.float16, np.float32 or np.float64.
    Always False when ``np`` is falsy (e.g. NumPy is unavailable).
    """
    if not np:
        return False
    float_scalar_types = (np.float16, np.float32, np.float64)
    return istype(value, float_scalar_types)
@overload
def is_lru_cache_wrapped_function(
    value: Callable[..., T],
) -> TypeGuard[functools._lru_cache_wrapper[T]]: ...


@overload
def is_lru_cache_wrapped_function(
    value: Any,
) -> TypeGuard[functools._lru_cache_wrapper[Any]]: ...


def is_lru_cache_wrapped_function(
    value: Any,
) -> bool:
    """
    True if ``value`` is a functools.lru_cache wrapper whose ``__wrapped__``
    (looked up statically) is a plain or builtin function per is_function().
    """
    if not isinstance(value, functools._lru_cache_wrapper):
        return False
    wrapped = inspect.getattr_static(value, "__wrapped__")
    return is_function(wrapped)
- _FuncTypes: TypeAlias = Union[
- types.FunctionType,
- types.BuiltinFunctionType,
- types.MethodDescriptorType,
- types.WrapperDescriptorType,
- ]
def is_function_or_wrapper(
    value: Any,
) -> TypeIs[Union[_FuncTypes, torch._ops.OpOverloadPacket, torch._ops.OpOverload]]:
    """True if ``value`` satisfies is_function() or is a torch op overload
    (packet or a specific overload)."""
    if is_function(value):
        return True
    return isinstance(value, (torch._ops.OpOverloadPacket, torch._ops.OpOverload))
def is_function(
    value: Any,
) -> TypeIs[_FuncTypes]:
    """True if ``value`` is a Python function, builtin function, method
    descriptor or slot wrapper descriptor."""
    function_like = (
        types.FunctionType,  # def / lambda
        types.BuiltinFunctionType,  # C functions, e.g. len
        types.MethodDescriptorType,  # e.g. list.append
        types.WrapperDescriptorType,  # e.g. list.__add__
    )
    return isinstance(value, function_like)
# Rich-comparison dunder name -> the `operator` function implementing it.
cmp_name_to_op_mapping = dict(
    __eq__=operator.eq,
    __ne__=operator.ne,
    __lt__=operator.lt,
    __le__=operator.le,
    __gt__=operator.gt,
    __ge__=operator.ge,
)

# Rich-comparison dunder name -> its infix operator spelling.
cmp_name_to_op_str_mapping = dict(
    __eq__="==",
    __ne__="!=",
    __lt__="<",
    __le__="<=",
    __gt__=">",
    __ge__=">=",
)
def is_wrapper_or_member_descriptor(
    value: Any,
) -> TypeIs[
    Union[
        types.GetSetDescriptorType,
        types.MethodDescriptorType,
        types.WrapperDescriptorType,
        types.MemberDescriptorType,
        types.MethodWrapperType,
    ]
]:
    """True if ``value`` is one of the C-level descriptor or wrapper objects."""
    descriptor_types = (
        types.GetSetDescriptorType,  # set up by PyGetSetDef
        types.MethodDescriptorType,  # set by PyMethodDef, e.g. list.append
        types.WrapperDescriptorType,  # slots, e.g. list.__add__
        types.MemberDescriptorType,  # set up by PyMemberDef
        types.MethodWrapperType,  # bound wrapper over C functions
    )
    return isinstance(value, descriptor_types)
def unwrap_if_wrapper(fn: Any) -> Any:
    """Return the callable that unwrap_with_attr_name_if_wrapper() yields for
    ``fn``, discarding the attribute name."""
    unwrapped, _attr_name = unwrap_with_attr_name_if_wrapper(fn)
    return unwrapped
def unwrap_with_attr_name_if_wrapper(fn: Any) -> tuple[Any, Optional[str]]:
    """
    If ``fn`` is a function carrying a truthy ``_torchdynamo_inline`` attribute
    (looked up statically), return ``(that attribute, "_torchdynamo_inline")``;
    otherwise return ``(fn, None)``.
    """
    # TODO(anijain2305) - Investigate if we can get rid of this function
    # unpack @torch._dynamo.optimize()(fn) wrapped function
    inline_attr = "_torchdynamo_inline"
    if not (is_function(fn) and inspect.getattr_static(fn, inline_attr, False)):
        return fn, None
    return inspect.getattr_static(fn, inline_attr, fn), inline_attr
def is_numpy_ndarray(value: Any) -> TypeGuard[np.ndarray]:  # type: ignore[type-arg]
    """True iff ``value``'s exact type is numpy.ndarray (subclasses rejected);
    False when ``np`` is falsy."""
    if np:
        return istype(value, np.ndarray)
    return False
def istensor(obj: Any) -> bool:
    """
    Check if ``obj``'s exact type is a tensor type: torch.Tensor,
    torch.nn.Parameter, any entry of ``config.traceable_tensor_subclasses``, or
    FakeTensor. Unlisted subclasses do not count (see istype).
    """
    exact_tensor_types: tuple[type, ...] = (
        torch.Tensor,
        torch.nn.Parameter,
        *config.traceable_tensor_subclasses,
        torch._subclasses.FakeTensor,
    )
    return istype(obj, exact_tensor_types)
def is_lazy_module(mod: Any) -> bool:
    """True if ``mod`` is an instance of LazyModuleMixin (or a subclass),
    e.g. torch.nn.LazyLinear."""
    lazy_base = LazyModuleMixin
    return isinstance(mod, lazy_base)
@functools.lru_cache(maxsize=4096)
def print_once(*args: Any) -> None:
    """print(*args), but only the first time a given (hashable) argument tuple
    is seen; memoized through an LRU cache of 4096 entries."""
    print(*args)
def make_cell(val: Any = None) -> types.CellType:
    """
    Create a fresh cell object holding ``val``.

    Cells normally only exist inside a function's ``__closure__``, so build a
    closure over ``val`` and hand back its single cell.
    """

    def _capture() -> Any:
        return val

    cells = _capture.__closure__
    assert cells is not None and len(cells) == 1
    return cells[0]
def proxy_args_kwargs(args: Any, kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any]]:
    """
    Convert each positional and keyword argument via its ``as_proxy()`` method.

    Returns ``(proxied_args_tuple, proxied_kwargs_dict)``. If any ``as_proxy()``
    raises NotImplementedError, the failure is reported through ``unimplemented``.
    """
    try:
        proxied = tuple(arg.as_proxy() for arg in args)
        proxied_kw = {name: arg.as_proxy() for name, arg in kwargs.items()}
    except NotImplementedError as e:
        from .exc import unimplemented
        from .variables.base import typestr

        unimplemented(
            gb_type="Failed to convert args/kwargs to proxy",
            context=f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}",
            explanation="Missing `as_proxy()` implementation for some arg/kwarg.",
            hints=[],
            from_exc=e,
        )
    else:
        return proxied, proxied_kw
def to_int_ms(v: Optional[float]) -> Optional[int]:
    """Convert seconds to whole milliseconds (truncated toward zero); None passes through."""
    if v is None:
        return None
    return int(v * 1000)
# A float64 epoch timestamp (as of 2024) resolves to roughly a quarter of a
# microsecond, so converting seconds -> microseconds this way is not ideal but
# should not meaningfully lose precision.
def to_int_us(v: Optional[float]) -> Optional[int]:
    """Convert seconds to whole microseconds (truncated toward zero); None passes through."""
    if v is None:
        return None
    return int(v * 1_000_000)
# Format version stamped into every CompilationMetrics entry (the
# log_format_version field). Bump it whenever how those logs are populated
# changes substantively, so new and old entries can be told apart.
LOG_FORMAT_VERSION = 3
- @dataclasses.dataclass
- class CompilationMetrics:
- compile_id: Optional[str] = None
- frame_key: Optional[str] = None
- co_name: Optional[str] = None
- co_filename: Optional[str] = None
- co_firstlineno: Optional[int] = None
- cache_size: Optional[int] = None
- accumulated_cache_size: Optional[int] = None
- guard_count: Optional[int] = None
- shape_env_guard_count: Optional[int] = None
- graph_op_count: Optional[int] = None
- graph_node_count: Optional[int] = None
- graph_input_count: Optional[int] = None
- start_time: Optional[float] = None
- entire_frame_compile_time_s: Optional[float] = None
- backend_compile_time_s: Optional[float] = None
- inductor_compile_time_s: Optional[float] = None
- code_gen_time_s: Optional[float] = None
- fail_type: Optional[str] = None
- fail_reason: Optional[str] = None
- fail_user_frame_filename: Optional[str] = None
- fail_user_frame_lineno: Optional[int] = None
- non_compliant_ops: Optional[set[str]] = None
- compliant_custom_ops: Optional[set[str]] = None
- restart_reasons: Optional[set[str]] = None
- dynamo_time_before_restart_s: Optional[float] = None
- stack_trace: Optional[list[str]] = None
- exception_stack_trace: Optional[list[str]] = None
- graph_node_shapes: Optional[str] = None
- # Sometimes, we will finish analyzing a frame but conclude we don't want
- # to install any guarded code. True means we actually decided to install
- # a compiled frame
- has_guarded_code: Optional[bool] = None
- remote_cache_time_saved_s: Optional[float] = None
- structured_logging_overhead_s: Optional[float] = None
- config_suppress_errors: Optional[bool] = None
- config_inline_inbuilt_nn_modules: Optional[bool] = None
- specialize_float: Optional[bool] = None
- dynamo_config: Optional[str] = None
- compiler_config: Optional[str] = None
- is_forward: Optional[bool] = None
- num_triton_bundles: Optional[int] = None
- remote_fx_graph_cache_get_time_ms: Optional[int] = None
- remote_fx_graph_cache_put_time_ms: Optional[int] = None
- start_time_us: Optional[int] = None
- duration_us: Optional[int] = None
- dynamo_cumulative_compile_time_us: Optional[int] = None
- aot_autograd_cumulative_compile_time_us: Optional[int] = None
- inductor_cumulative_compile_time_us: Optional[int] = None
- inductor_code_gen_cumulative_compile_time_us: Optional[int] = None
- triton_compile_time_us: Optional[int] = None
- runtime_cudagraphify_time_us: Optional[int] = None
- runtime_triton_autotune_time_us: Optional[int] = None
- dynamo_compile_time_before_restart_us: Optional[int] = None
- distributed_ephemeral_timeout_us: Optional[int] = None
- structured_logging_overhead_us: Optional[int] = None
- remote_fx_graph_cache_get_time_us: Optional[int] = None
- remote_fx_graph_cache_put_time_us: Optional[int] = None
- backward_cumulative_compile_time_us: Optional[int] = None
- end_time_us: Optional[int] = None
- pre_grad_pass_time_us: Optional[int] = None
- post_grad_pass_time_us: Optional[int] = None
- joint_graph_pass_time_us: Optional[int] = None
- log_format_version: int = LOG_FORMAT_VERSION
- inductor_config: Optional[str] = None
- remote_cache_version: Optional[int] = None
- inductor_fx_remote_cache_hit_count: Optional[int] = 0
- inductor_fx_remote_cache_miss_count: Optional[int] = 0
- inductor_fx_remote_cache_backend_type: Optional[str] = None
- inductor_fx_remote_cache_hit_keys: Optional[str] = None
- inductor_fx_remote_cache_miss_keys: Optional[str] = None
- inductor_fx_local_cache_hit_count: Optional[int] = 0
- inductor_fx_local_cache_miss_count: Optional[int] = 0
- aotautograd_remote_cache_hit_count: Optional[int] = 0
- aotautograd_remote_cache_miss_count: Optional[int] = 0
- aotautograd_local_cache_hit_count: Optional[int] = 0
- aotautograd_local_cache_miss_count: Optional[int] = 0
- cuda_version: Optional[str] = None
- triton_version: Optional[str] = None
- feature_usage: Optional[dict[str, bool]] = None
- compile_time_autotune_time_us: Optional[int] = None
- is_runtime: Optional[bool] = False
- gc_time_us: Optional[int] = None
- tensorify_float_attempt: Optional[bool] = None
- tensorify_float_success: Optional[bool] = None
- tensorify_float_failure: Optional[set[str]] = None
- guard_latency_us: Optional[float] = None
- recompile_reason: Optional[str] = None
- num_graph_breaks: Optional[int] = None
- triton_kernel_compile_times_us: Optional[str] = None
- ir_count: Optional[int] = None
- cudagraph_skip_reason: Optional[str] = None
- python_version: Optional[str] = None
- pgo_put_remote_code_state_time_us: Optional[int] = None
- pgo_get_remote_code_state_time_us: Optional[int] = None
- # The number of elements within parameters. This is classically what people
- # think of when they think of parameters in a ML model.
- param_numel: Optional[int] = None
- # The number of elements counted by bytes - i.e. a float32 is 4 bytes
- # per element.
- param_bytes: Optional[int] = None
- # The number of parameters counted by fields. This is mostly a proxy for
- # the number of distinct type of params.
- param_count: Optional[int] = None
- recompile_user_contexts: Optional[set[str]] = None
- inline_inbuilt_nn_modules_candidate: Optional[bool] = False
- pytorch_version: Optional[str] = None
- inductor_provenance: Optional[set[str]] = None
@classmethod
def create(cls, metrics: dict[str, Any]) -> CompilationMetrics:
    """
    Factory method to create a CompilationMetrics from a dict of fields.

    Includes the logic to add legacy fields and any pre-processing, e.g.,
    we transform some fields to comma-separated strings for scuba logging.

    :param metrics: Mapping of field name to value. Explicit entries take
        precedence over the legacy fields derived from them.
    :return: An instance built from the legacy fields merged with ``metrics``
        after pre-processing.
    """

    def us_to_s(metric: Optional[int]) -> Optional[float]:
        return metric / 1e6 if metric is not None else None

    def us_to_ms(metric: Optional[int]) -> Optional[int]:
        return metric // 1000 if metric is not None else None

    def safe_str(item: Any) -> str:
        try:
            return str(item)
        except Exception:
            return "<unknown>"

    def collection_to_str(metric: Optional[Any]) -> Optional[str]:
        if metric is None:
            return None
        if not isinstance(metric, (set, list)):
            return "<unknown>"
        try:
            ordered = sorted(metric)
        except Exception:
            # Mixed or unorderable element types (e.g. ints and strs): order by
            # the string form instead so logging never crashes on odd keys.
            ordered = sorted(metric, key=safe_str)
        return ",".join(safe_str(item) for item in ordered)

    def collection_to_json_str(metric: Optional[Any]) -> Optional[str]:
        if metric is None:
            return None
        try:
            return json.dumps(list(metric))
        except Exception:
            return "<unknown>"

    # TODO: The following are legacy fields, populated from the fields that replace
    # them. Remove these when we decide we can really deprecate them.
    legacy_metrics = {
        "start_time": us_to_s(metrics.get("start_time_us")),
        "entire_frame_compile_time_s": us_to_s(
            metrics.get("dynamo_cumulative_compile_time_us")
        ),
        "backend_compile_time_s": us_to_s(
            metrics.get("aot_autograd_cumulative_compile_time_us")
        ),
        "inductor_compile_time_s": us_to_s(
            metrics.get("inductor_cumulative_compile_time_us")
        ),
        "code_gen_time_s": us_to_s(
            metrics.get("inductor_code_gen_cumulative_compile_time_us")
        ),
        "remote_cache_time_saved_s": us_to_s(
            metrics.get("distributed_ephemeral_timeout_us")
        ),
        "remote_fx_graph_cache_get_time_ms": us_to_ms(
            metrics.get("remote_fx_graph_cache_get_time_us")
        ),
        "remote_fx_graph_cache_put_time_ms": us_to_ms(
            metrics.get("remote_fx_graph_cache_put_time_us")
        ),
        "structured_logging_overhead_s": us_to_s(
            metrics.get("structured_logging_overhead_us")
        ),
    }

    # Explicitly provided metrics override the derived legacy values.
    all_metrics = {**legacy_metrics, **metrics}

    # Processing before logging:
    all_metrics["inductor_fx_remote_cache_hit_keys"] = collection_to_str(
        all_metrics.get("inductor_fx_remote_cache_hit_keys")
    )
    all_metrics["inductor_fx_remote_cache_miss_keys"] = collection_to_str(
        all_metrics.get("inductor_fx_remote_cache_miss_keys")
    )
    all_metrics["triton_kernel_compile_times_us"] = collection_to_json_str(
        all_metrics.get("triton_kernel_compile_times_us")
    )
    compile_id = all_metrics.get("compile_id")
    all_metrics["compile_id"] = str(compile_id) if compile_id else None

    # pyrefly: ignore [bad-argument-type]
    return cls(**all_metrics)
# Default number of CompilationMetrics records retained in memory.
DEFAULT_COMPILATION_METRICS_LIMIT = 64

# Bounded buffer of recently recorded CompilationMetrics; once it is full,
# appending a new record drops the oldest one.
_compilation_metrics: collections.deque[CompilationMetrics]
_compilation_metrics = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT)
def add_compilation_metrics_to_chromium(c: CompilationMetrics) -> None:
    """
    Attach a fixed subset of CompilationMetrics fields to the outermost
    in-flight chromium event (the `dynamo`/toplevel event in PT2 Compile Events).

    These are the common fields that existed before metrics_context and are not
    set by MetricsContext.set(). Nothing happens when no chromium event is in
    progress.

    If you're tempted to add to this list, consider using CompileEventLogger.compilation_metric()
    instead, which will automatically also add it to tlparse and PT2 Compile Events.

    TODO: Get rid of this function and replace it with CompileEventLogger directly instead.
    """

    def as_list(values: Any) -> Optional[list[Any]]:
        # Sets aren't JSON serializable, so collections are logged as lists.
        return None if values is None else list(values)

    chromium_log = get_chromium_event_logger()
    outermost = chromium_log.get_outermost_event()
    if not outermost:
        return

    chromium_log.add_event_data(
        event_name=outermost,
        frame_key=c.frame_key,
        co_name=c.co_name,
        co_filename=c.co_filename,
        co_firstlineno=c.co_firstlineno,
        cache_size=c.cache_size,
        accumulated_cache_size=c.accumulated_cache_size,
        guard_count=c.guard_count,
        shape_env_guard_count=c.shape_env_guard_count,
        graph_op_count=c.graph_op_count,
        graph_node_count=c.graph_node_count,
        graph_input_count=c.graph_input_count,
        fail_type=c.fail_type,
        fail_reason=c.fail_reason,
        fail_user_frame_filename=c.fail_user_frame_filename,
        fail_user_frame_lineno=c.fail_user_frame_lineno,
        non_compliant_ops=as_list(c.non_compliant_ops),
        compliant_custom_ops=as_list(c.compliant_custom_ops),
        restart_reasons=as_list(c.restart_reasons),
        dynamo_time_before_restart_s=c.dynamo_time_before_restart_s,
        has_guarded_code=c.has_guarded_code,
        dynamo_config=c.dynamo_config,
    )
def _get_dynamo_config_for_logging() -> Optional[str]:
    """
    Serialize a copy of the dynamo config to a sorted-key JSON string.

    Keys in a fixed exclusion list are dropped, and set-valued entries are
    emitted as sorted lists so the result is JSON serializable and stable.
    """
    excluded_keys = frozenset(
        (
            "TYPE_CHECKING",
            "log_file_name",
            "verbose",
            "repro_after",
            "repro_level",
            "repro_forward_only",
            "repro_tolerance",
            "repro_ignore_non_fp",
            "same_two_models_use_fp64",
            "base_dir",
            "debug_dir_root",
            "_save_config_ignore",
            "log_compilation_metrics",
            "inject_BUILD_SET_unimplemented_TESTING_ONLY",
            "_autograd_backward_strict_mode_banned_ops",
            "reorderable_logging_functions",
            "ignore_logger_methods",
            "traceable_tensor_subclasses",
            "nontraceable_tensor_subclasses",
            "_custom_ops_profile",
        )
    )
    loggable: dict[str, Any] = {}
    for key, value in config.get_config_copy().items():
        if key in excluded_keys:
            continue
        loggable[key] = sorted(value) if isinstance(value, set) else value
    return json.dumps(loggable, sort_keys=True)
def _compiler_config_for_logging() -> Optional[str]:
    """
    Serialize a copy of ``torch.compiler.config`` to a sorted-key JSON string.

    Returns None when the config module is falsy, and a fixed message when a
    copy of the config cannot be taken (TypeError/AttributeError). The
    ``TYPE_CHECKING`` entry is dropped and set values become sorted lists.
    """
    compiler_config = torch.compiler.config
    if not compiler_config:
        return None
    try:
        snapshot = compiler_config.get_config_copy()  # type: ignore[attr-defined]
    except (TypeError, AttributeError):
        return "Compiler Config cannot be pickled"
    cleaned = {
        name: (sorted(setting) if isinstance(setting, set) else setting)
        for name, setting in snapshot.items()
        if name != "TYPE_CHECKING"
    }
    return json.dumps(cleaned, sort_keys=True)
class _InductorConfigJSONEncoder(json.JSONEncoder):
    """JSON encoder that substitutes a placeholder for values it cannot encode."""

    def default(self, o: Any) -> Any:
        try:
            return super().default(o)
        except Exception:
            return "Value is not JSON serializable"


def _inductor_config_to_json(config_copy: dict[Any, Any]) -> str:
    """Serialize a copied inductor config dict; returns a fixed message instead of raising."""
    try:
        loggable: dict[str, Any] = {}
        for key, val in config_copy.items():
            # Scrub non-string keys.
            if not isinstance(key, str):
                continue
            if isinstance(val, set):
                # Sets aren't JSON serializable. Sort when the elements allow it
                # so the logged string does not depend on set iteration order
                # (which varies with hash seeding for str elements).
                try:
                    val = sorted(val)
                except Exception:
                    val = list(val)
            loggable[key] = val
        return json.dumps(
            loggable,
            cls=_InductorConfigJSONEncoder,
            skipkeys=True,
            sort_keys=True,
        )
    except Exception:
        # Don't crash because of runtime logging errors
        return "Inductor Config is not JSON serializable"


def _scrubbed_inductor_config_for_logging() -> Optional[str]:
    """
    Method to parse and scrub uninteresting configs from inductor config.

    :return: A sorted-key JSON string of ``torch._inductor.config`` with
        non-string keys removed and sets turned into (sorted where possible)
        lists; None if the config module is falsy or its copy is None;
        "Inductor Config cannot be pickled" if copying fails; or
        "Inductor Config is not JSON serializable" if serialization fails.
    """
    if not torch._inductor.config:
        return None
    try:
        inductor_config_copy = torch._inductor.config.get_config_copy()
    except (TypeError, AttributeError, RuntimeError, AssertionError):
        return "Inductor Config cannot be pickled"
    if inductor_config_copy is None:
        return None
    return _inductor_config_to_json(inductor_config_copy)
def record_compilation_metrics(
    start_time_ns: int,
    end_time_ns: int,
    metrics: dict[str, Any],
    exc_type: Optional[type[BaseException]],
    exc_value: Optional[BaseException],
) -> None:
    """
    Build a CompilationMetrics record and publish it to every sink.

    The record is appended to the in-memory ``_compilation_metrics`` buffer,
    emitted via ``torch._logging.trace_structured``, attached to the outermost
    in-flight chromium event, and, when ``config.log_compilation_metrics`` is
    set, passed to ``log_compilation_event``.

    :param start_time_ns: Start timestamp in nanoseconds.
    :param end_time_ns: End timestamp in nanoseconds.
    :param metrics: Metric values for this compile; they override the common
        fields computed here.
    :param exc_type: Type of the exception that ended the compile, if any.
    :param exc_value: The exception instance, if any.
    """
    # Remote FX graph cache details are reported only when that cache is in use
    # and the remote_cache module can be imported.
    remote_cache_version = None
    inductor_fx_remote_cache_backend_type = None
    if torch._inductor.utils.should_use_remote_fx_graph_cache():
        try:
            from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION
        except ModuleNotFoundError:
            pass
        else:
            remote_cache_version = REMOTE_CACHE_VERSION
            inductor_fx_remote_cache_backend_type = "_ManifoldCache"

    # Populate the compile_id from the metrics context if it's set. Otherwise,
    # look for it in the current compile context.
    compile_id = (
        metrics.get("compile_id")
        or torch._guards.CompileContext.current_compile_id()
    )

    common_metrics = {
        "compile_id": compile_id,
        "start_time_us": start_time_ns // 1000,
        "end_time_us": end_time_ns // 1000,
        "fail_type": exc_type.__qualname__ if exc_type else None,
        "fail_reason": str(exc_value) if exc_value else None,
        "structured_logging_overhead_us": to_int_us(
            torch._logging.get_structured_logging_overhead()
        ),
        "dynamo_config": _get_dynamo_config_for_logging(),
        "config_suppress_errors": config.suppress_errors,
        "config_inline_inbuilt_nn_modules": config.inline_inbuilt_nn_modules,
        "inductor_config": _scrubbed_inductor_config_for_logging(),
        "compiler_config": _compiler_config_for_logging(),
        "cuda_version": torch.version.cuda,
        "triton_version": triton.__version__ if has_triton() else "",
        "remote_cache_version": remote_cache_version,
        "inductor_fx_remote_cache_backend_type": inductor_fx_remote_cache_backend_type,
        "python_version": sys.version,
        "pytorch_version": torch.__version__,
    }
    compilation_metrics = CompilationMetrics.create({**common_metrics, **metrics})
    _compilation_metrics.append(compilation_metrics)

    prefix = "bwd_" if compilation_metrics.is_forward is False else ""
    suffix = "_runtime" if compilation_metrics.is_runtime is True else ""
    name = f"{prefix}compilation_metrics{suffix}"

    def jsonable_payload() -> dict[str, Any]:
        # Sets are converted to lists so the payload can be JSON-encoded.
        payload: dict[str, Any] = {}
        for field_name, value in dataclasses.asdict(compilation_metrics).items():
            payload[field_name] = list(value) if isinstance(value, set) else value
        return payload

    torch._logging.trace_structured(
        name,
        jsonable_payload,
        # NB: Because compilation metrics *includes* the logging overhead time,
        # we can't both *measure* the logging overhead of compilation metrics
        # without making it inconsistent with compilation metrics itself, so
        # we ignore the (hopefully small) time spent logging compilation metrics
        record_logging_overhead=False,
        # These may be runtime logs, e.g., runtime autotunning, so we provide
        # the CompileId from the compilation metrics in case it's not available
        # in the current trace.
        compile_id=compile_id,
    )

    # If there's a chromium event in flight, add the CompilationMetrics to it.
    add_compilation_metrics_to_chromium(compilation_metrics)

    # Finally log the compilation metrics.
    if config.log_compilation_metrics:
        log_compilation_event(compilation_metrics)
def set_compilation_metrics_limit(new_size: int) -> None:
    """
    Change how many CompilationMetrics records are retained.

    The most recent ``new_size`` records are kept and older ones are dropped.
    The existing buffer is trimmed in place before being replaced by a deque
    with the new ``maxlen``.

    :param new_size: New maximum number of records; must be non-negative.
    :raises ValueError: If ``new_size`` is negative. The buffer is left
        untouched in that case (previously every record was popped and an
        IndexError escaped from ``popleft`` on the emptied deque).
    """
    global _compilation_metrics
    if new_size < 0:
        raise ValueError(
            f"compilation metrics limit must be non-negative, got {new_size}"
        )
    while len(_compilation_metrics) > new_size:
        _compilation_metrics.popleft()
    _compilation_metrics = collections.deque(_compilation_metrics, maxlen=new_size)
def clear_compilation_metrics() -> None:
    """Discard every retained CompilationMetrics record."""
    # The deque is emptied in place rather than rebound, so no `global`
    # declaration is required.
    _compilation_metrics.clear()
def get_compilation_metrics() -> list[CompilationMetrics]:
    """Return a list snapshot of the retained records, in deque order (left to right)."""
    snapshot: list[CompilationMetrics] = [*_compilation_metrics]
    return snapshot
class ChromiumEventLogger:
    """Logs chromium events to structured logs. tlparse will concatenate these into a perfetto UI link.

    Also emits RecordFunction calls to torch.profiler when enabled.

    See https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview#heading=h.yr4qxyxotyw for
    a specification of the Chromium Event JSON format.

    All bookkeeping (the event stack, the PT2 compile substack, per-event
    metadata and open profiler record functions) is stored on a
    ``threading.local``, so each thread tracks its own in-flight events.
    """

    def get_stack(self) -> list[str]:
        """
        The main event stack, with every chromium event.
        Logged to tlparse.
        """
        if not hasattr(self.tls, "stack"):
            self.tls.stack = []
        return self.tls.stack

    def get_outermost_event(self) -> Optional[str]:
        """
        Get the outermost event name (i.e. the longest running event)
        or None if the stack is empty.
        """
        stack = self.get_stack()
        return stack[0] if stack else None

    def get_pt2_compile_substack(self) -> list[str]:
        """
        A smaller subset of the main stack that gets used to log
        PT2 Compile Events internally.
        """
        if not hasattr(self.tls, "pt2_compile_substack"):
            self.tls.pt2_compile_substack = []
        return self.tls.pt2_compile_substack

    def get_event_data(self) -> dict[str, Any]:
        """Metadata collected for in-progress events, keyed by event name."""
        if not hasattr(self.tls, "event_data"):
            self.tls.event_data = {}
        return self.tls.event_data

    def get_record_functions(self) -> dict[str, AbstractContextManager[None]]:
        """Open profiler record functions for in-progress events, keyed by event name."""
        if not hasattr(self.tls, "record_functions"):
            self.tls.record_functions = {}
        return self.tls.record_functions

    def __init__(self) -> None:
        self.tls = threading.local()

        from . import config

        # Generate a unique id for this logger, which we can use in scuba to filter down
        # to a single python run.
        if config.pt2_compile_id_prefix:
            self.id_ = f"{config.pt2_compile_id_prefix}-{uuid.uuid4()}"
        else:
            self.id_ = str(uuid.uuid4())

        # TODO: log to init/id tlparse after I add support for it
        log.info("ChromiumEventLogger initialized with id %s", self.id_)

    def _check_event_in_progress(self, event_name: str) -> None:
        """Raise RuntimeError unless ``event_name`` is on this thread's event stack."""
        stack = self.get_stack()
        if event_name not in stack:
            raise RuntimeError(
                f"Event {repr(event_name)} not in {stack}. "
                "Cannot add metadata to events that aren't in progress. "
                "Please make sure the event has started and hasn't ended."
            )

    def _event_data_for(self, event_name: str) -> dict[str, Any]:
        """Return the metadata dict for ``event_name``, creating it if needed."""
        return self.get_event_data().setdefault(event_name, {})

    def try_add_event_data(self, event_name: str, **kwargs: Any) -> None:
        """
        Same as add_event_data, but will silently not log if the event isn't in the stack.
        """
        if event_name not in self.get_stack():
            return
        self.add_event_data(event_name, **kwargs)

    def add_event_data(
        self,
        event_name: str,
        **kwargs: Any,
    ) -> None:
        """
        Adds additional metadata info to an in-progress event
        This metadata is recorded in the END event

        :raises RuntimeError: If ``event_name`` is not currently in progress.
        """
        self._check_event_in_progress(event_name)
        self._event_data_for(event_name).update(kwargs)

    def increment(self, event_name: str, key: str, value: int) -> None:
        """
        Increment an integer event data field by the given amount.
        A missing field starts at 0.

        :raises RuntimeError: If ``event_name`` is not currently in progress.
        """
        self._check_event_in_progress(event_name)
        data = self._event_data_for(event_name)
        data[key] = data.get(key, 0) + value

    def add_to_set(
        self,
        event_name: str,
        key: str,
        value: Any,
    ) -> None:
        """
        Add a value to a set within a event_name's metadata, creating the
        set if it does not exist yet.

        :raises RuntimeError: If ``event_name`` is not currently in progress.
        """
        self._check_event_in_progress(event_name)
        self._event_data_for(event_name).setdefault(key, set()).add(value)

    def log_event_start(
        self,
        event_name: str,
        time_ns: int,
        metadata: dict[str, Any],
        log_pt2_compile_event: bool = False,
        compile_id: Optional[CompileId] = None,
    ) -> None:
        """
        Logs the start of a single event.

        :param str event_name Name of event to appear in trace
        :param time_ns Timestamp in nanoseconds
        :param metadata: Any extra metadata associated with this event.
            Note: a "compile_id" entry is written into this dict.
        :param log_pt2_compile_event: If True, log to pt2_compile_events
        :param compile_id: Explicit compile_id (rather than using the current context)
        """
        compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
        metadata["compile_id"] = str(compile_id)
        self._log_timed_event(
            event_name,
            time_ns,
            "B",
            metadata,
        )
        self.get_stack().append(event_name)
        # Add metadata from start event
        self.add_event_data(event_name, **metadata)
        if log_pt2_compile_event:
            self.get_pt2_compile_substack().append(event_name)

        # Emit profiler event so compilation events show up in stock PyTorch profiler
        if torch.autograd.profiler._is_profiler_enabled:
            rf = torch._C._profiler._RecordFunctionFast(event_name)
            rf.__enter__()
            self.get_record_functions()[event_name] = rf
            # Add metadata to the profiler event if present
            if metadata:
                CompileEventLogger.add_record_function_data(event_name, **metadata)

    def reset(self) -> None:
        """Clear this thread's stacks, event data and open record functions."""
        # We call this on every compile in case a compile crashes or restarts
        # and we haven't cleared the stack.
        stack = self.get_stack()
        substack = self.get_pt2_compile_substack()
        stack.clear()
        substack.clear()
        event_data = self.get_event_data()
        event_data.clear()
        # Clean up any lingering record functions (shouldn't happen in normal operation)
        record_functions = self.get_record_functions()
        if record_functions:
            for rf in record_functions.values():
                rf.__exit__(None, None, None)
            record_functions.clear()

    def log_event_end(
        self,
        event_name: str,
        time_ns: int,
        metadata: dict[str, Any],
        start_time_ns: int,
        log_pt2_compile_event: bool,
        compile_id: Optional[CompileId] = None,
    ) -> None:
        """
        Logs the end of a single event. This function should only be
        called after log_event_start with the same event_name.

        :param event_name: Name of event to appear in trace
        :param time_ns: Timestamp in nanoseconds
        :param metadata: Any extra metadata associated with this event.
            Note: a "compile_id" entry is written into this dict.
        :param start_time_ns: The start time timestamp in nanoseconds
        :param log_pt2_compile_event: If True, log to pt2_compile_events
        :param compile_id: Explicit compile_id (rather than using the current context)
        """
        compile_id = compile_id or torch._guards.CompileContext.current_compile_id()
        metadata["compile_id"] = str(compile_id)

        # Grab metadata collected during event span
        all_event_data = self.get_event_data()
        if event_name in all_event_data:
            event_metadata = all_event_data[event_name]
            del all_event_data[event_name]
        else:
            # pyrefly: ignore [implicit-any]
            event_metadata = {}
        # Add the passed in metadata
        event_metadata.update(metadata)

        event = self._log_timed_event(
            event_name,
            time_ns,
            "E",
            event_metadata,
        )

        def pop_stack(stack: list[str]) -> None:
            # Callers must ensure event_name is in `stack`; otherwise this loop
            # would exhaust the stack and raise IndexError.
            while event_name != stack[-1]:
                # If the event isn't the most recent one to end, pop
                # off the stack until it is.
                log.warning(
                    "ChromiumEventLogger: Detected overlapping events, fixing stack"
                )
                stack.pop()

        event_stack = self.get_stack()
        # These stack health checks currently never happen,
        # but they're written this way to future proof any weird event
        # overlaps in the future.
        if event_name not in event_stack:
            # Something went wrong, we never called start on this event,
            # or it was skipped due to overlapping events below
            log.warning("ChromiumEventLogger: Start event not in stack, ignoring")
            return

        pop_stack(event_stack)

        if log_pt2_compile_event:
            pt2_compile_substack = self.get_pt2_compile_substack()
            if event_name in pt2_compile_substack:
                pop_stack(pt2_compile_substack)
                log_chromium_event_internal(
                    event, pt2_compile_substack, self.id_, start_time_ns
                )
                # Pop actual event off of stack
                pt2_compile_substack.pop()
            else:
                # The event was started without log_pt2_compile_event, or was
                # already evicted as an overlapping event. Popping would empty
                # the substack and raise IndexError, leaving the main stack
                # and record functions inconsistent, so skip instead.
                log.warning(
                    "ChromiumEventLogger: Start event not in pt2 compile substack, ignoring"
                )

        # Finally pop the actual event off the stack
        event_stack.pop()

        # End profiler event so compilation events show up in stock PyTorch profiler
        record_functions = self.get_record_functions()
        if event_name in record_functions:
            rf = record_functions.pop(event_name)
            rf.__exit__(None, None, None)

    def _log_timed_event(
        self,
        event_name: str,
        time_ns: int,
        phase: str,
        metadata: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """
        Logs a timed event in chromium format. See log_event_start, log_event_end, etc.
        Returns the event dict that was logged.
        """
        event = {
            "name": event_name,
            "ts": time_ns / 1000,  # Chromium events are in micro seconds
            "args": metadata,
            "ph": phase,
            # These categories are needed in all chromium traces
            "cat": "dynamo_timed",
            "tid": 0,
            "pid": 0,  # pid should be specified on all logs, we don't personally care about the actual process id
        }
        torch._logging.trace_structured(
            "chromium_event",
            payload_fn=lambda: event,
            suppress_context=False,
            expect_trace_id=False,  # Not every chromium event will have a trace_id
        )
        record_chromium_event_internal(event)
        return event

    def log_instant_event(
        self,
        event_name: str,
        time_ns: int,
        metadata: Optional[dict[str, Any]] = None,
        # By default, an instant event isn't logged internally, only to structured logging.
        log_pt2_compile_event: bool = False,
    ) -> None:
        """
        Log an instant event with no associated duration.

        :param str event_name: Name of event to appear in trace
        :param int time_ns Timestamp in nanoseconds
        :param Optional[Dict[str, Any]] metadata: Any extra metadata associated with this event
        :param bool log_pt2_compile_event: If True, also log to pt2_compile_events
            (with identical start and end time)
        """
        if metadata is None:
            metadata = {}
        compile_id = str(torch._guards.CompileContext.current_compile_id())
        metadata["compile_id"] = compile_id
        event = {
            "name": event_name,
            "ts": time_ns / 1000,
            "args": metadata,
            "ph": "i",
            # These categories are needed in all chromium traces
            "cat": "dynamo_timed",
            "tid": 0,
            "pid": 0,
            "s": "p",  # We use "process" level instant events so they all appear on the same row in the trace.
        }
        torch._logging.trace_structured(
            "chromium_event",
            payload_fn=lambda: event,
            suppress_context=False,
            expect_trace_id=True,
        )
        if log_pt2_compile_event:
            # Log an instant event with the same start and end time
            log_chromium_event_internal(
                event, self.get_pt2_compile_substack(), self.id_, time_ns
            )
# Module-level shared ChromiumEventLogger, created lazily on first request.
CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None


def get_chromium_event_logger() -> ChromiumEventLogger:
    """Return the shared ChromiumEventLogger, constructing it on the first call."""
    global CHROMIUM_EVENT_LOG
    if CHROMIUM_EVENT_LOG is not None:
        return CHROMIUM_EVENT_LOG
    CHROMIUM_EVENT_LOG = ChromiumEventLogger()
    return CHROMIUM_EVENT_LOG
def chromium_event_log_active() -> bool:
    """Report whether the shared ChromiumEventLogger has been created yet."""
    # Only reads the module global, so no `global` declaration is needed.
    logger_created = CHROMIUM_EVENT_LOG is not None
    return logger_created
@contextmanager
def chromium_event_timed(
    event_name: str,
    reset_event_log_on_exit: bool = False,
    log_pt2_compile_event: bool = False,
) -> Generator[Any, None, None]:
    """
    Emit a chromium start event on entry and the matching end event on exit.

    Chromium event logging is integrated with dynamo_timed, so you probably want
    to use that instead; use this only when dynamo_timed must be avoided. The
    end event is logged even if the body raises.

    :param event_name: Name shown for the event in the trace.
    :param reset_event_log_on_exit: If True, call the logger's ``reset()``
        after the end event has been logged.
    :param log_pt2_compile_event: Forwarded to both the start and end calls.
    """
    logger = get_chromium_event_logger()
    started_ns = time.time_ns()
    logger.log_event_start(event_name, started_ns, {}, log_pt2_compile_event)
    try:
        yield
    finally:
        ended_ns = time.time_ns()
        logger.log_event_end(
            event_name, ended_ns, {}, started_ns, log_pt2_compile_event
        )
        if reset_event_log_on_exit:
            logger.reset()
@dataclasses.dataclass
class CleanupHook:
    """Remove a global variable when hook is called.

    Instances are made by :meth:`create`, which installs ``scope[name]`` and
    increments ``CleanupManager.count``; invoking the hook reverses both.
    """

    scope: dict[str, Any]
    name: str

    def __call__(self, *args: Any) -> None:
        # Make sure we're not shutting down: CleanupManager may already be
        # None while module globals are torn down.
        if CleanupManager is None:
            return
        CleanupManager.count -= 1
        del self.scope[self.name]

    @staticmethod
    def create(scope: dict[str, Any], name: str, val: Any) -> CleanupHook:
        """Bind ``scope[name] = val`` and return a hook that removes it."""
        assert name not in scope
        CleanupManager.count += 1
        scope[name] = val
        return CleanupHook(scope=scope, name=name)
class CleanupManager(ExactWeakKeyDictionary):
    """ExactWeakKeyDictionary whose entries hold sequences of CleanupHooks.

    ``count`` is a running tally: CleanupHook.create adds one and calling a
    CleanupHook subtracts one.
    """

    count = 0
    instance: ClassVar[CleanupManager]

    def _remove_id(self, idx: int) -> None:
        # Run every hook registered for this entry before the base class drops
        # it. Presumably invoked by ExactWeakKeyDictionary when a key is
        # collected — TODO confirm against the base class.
        registered_hooks = self.values[idx]
        for run_hook in registered_hooks:
            run_hook()
        super()._remove_id(idx)


# Shared singleton instance.
CleanupManager.instance = CleanupManager()
def clone_tensor(x: torch.Tensor) -> torch.Tensor:
    """Clone the tensor and its gradient.

    The copy keeps ``x.requires_grad``; for a leaf tensor with a ``.grad``,
    the gradient is cloned as well.
    """
    copied = x.clone()
    copied.requires_grad_(x.requires_grad)
    # Only leaves are asked for .grad (same short-circuit as `is_leaf and ...`).
    source_grad = x.grad if x.is_leaf else None
    if source_grad is not None:
        copied.grad = source_grad.clone()
    return copied
def _copy_dynamo_attr(src: torch.Tensor, dst: torch.Tensor, attr: str) -> None:
    """Set ``dst.<attr>`` to ``src.<attr>.copy()``, or delete it from ``dst`` when ``src`` lacks it."""
    missing = object()
    value = getattr(src, attr, missing)
    if value is not missing:
        setattr(dst, attr, value.copy())
    elif hasattr(dst, attr):
        delattr(dst, attr)
def copy_dynamo_tensor_attributes(src: torch.Tensor, dst: torch.Tensor) -> None:
    """
    Copy dynamo-specific tensor attributes from src to dst.

    These attributes are used for dynamic shape marking and must be preserved
    when cloning or casting tensors. Any attribute present on dst but absent
    on src is removed from dst.
    """
    for attr_name in (
        "_dynamo_dynamic_indices",
        "_dynamo_unbacked_indices",
        "_dynamo_hint_overrides",
        "_dynamo_shape_ids",
        "_dynamo_strict_unbacked_indices",
        "_dynamo_weak_dynamic_indices",
    ):
        _copy_dynamo_attr(src, dst, attr_name)
def clone_input(
    x: torch.Tensor, *, dtype: Optional[torch.dtype] = None
) -> torch.Tensor:
    """copy while preserving strides

    Copies ``x`` into a fresh buffer viewed with the same size and strides,
    optionally cast to ``dtype``.

    - Leaf tensors keep ``requires_grad`` and get a cloned ``.grad``.
    - Dynamo tensor attributes are copied over.
    - Fake tensors are returned unchanged.
    - XLA and traceable-wrapper-subclass tensors use a plain ``torch.clone``.
    - Sparse COO/compressed tensors are rebuilt from cloned components.
    - The XLA, subclass and sparse paths ignore ``dtype``.

    :param x: Tensor to copy.
    :param dtype: Optional dtype for the copy; defaults to ``x.dtype``.
    :return: The cloned tensor.
    """
    # TODO: this is questionable
    if is_fake(x):
        # this func fails on fake tensors in __torch_dispatch__
        return x

    def torch_clone(x: torch.Tensor) -> torch.Tensor:
        y = torch.clone(x)
        if x.is_leaf:
            y.requires_grad_(x.requires_grad)
        if x.is_leaf and x.grad is not None:
            y.grad = clone_input(x.grad, dtype=dtype)
        copy_dynamo_tensor_attributes(x, y)
        return y

    with torch.no_grad():
        if x.device.type == "xla":
            # Access data_ptr() for a xla tensor will cause crash
            return torch_clone(x)

        # Handle sparse storage (no stride).
        if x.layout is torch.sparse_coo:
            return torch.sparse_coo_tensor(
                torch_clone(x._indices()),
                torch_clone(x._values()),
                x.shape,
                is_coalesced=x.is_coalesced(),
            )
        elif is_sparse_compressed(x):
            if x.layout in {torch.sparse_csr, torch.sparse_bsr}:
                compressed_indices = x.crow_indices()
                plain_indices = x.col_indices()
            else:
                compressed_indices = x.ccol_indices()
                plain_indices = x.row_indices()
            return torch.sparse_compressed_tensor(
                torch_clone(compressed_indices),
                torch_clone(plain_indices),
                torch_clone(x.values()),
                x.shape,
                layout=x.layout,
            )
        elif is_traceable_wrapper_subclass(x):
            # Questionable - but this is required to not fail executorch related
            # torchao tests.
            return torch_clone(x)

        # Largest element offset addressed by x's size/stride. A zero-length
        # dimension contributes (0 - 1) * stride, which can push the sum well
        # below zero for empty tensors with large strides; clamp it so
        # torch.empty never receives a negative size.
        needed_size = max(
            sum((shape - 1) * stride for shape, stride in zip(x.size(), x.stride())),
            0,
        )
        if x.is_quantized:
            result = torch.empty_quantized((needed_size + 32,), x)
        else:
            result = torch.empty(
                needed_size + 32, dtype=dtype or x.dtype, device=x.device
            )
        # Shift the view's start so its address matches x's modulo 32 bytes;
        # the 32 spare elements allocated above leave room for this offset.
        cache_line_offset = (
            (x.data_ptr() - result.data_ptr()) % 32
        ) // x.element_size()
        result.as_strided_(x.size(), x.stride(), cache_line_offset)
        try:
            result.copy_(x.clone())
            if x.is_leaf:
                result.requires_grad_(x.requires_grad)
            if x.is_leaf and x.grad is not None:
                result.grad = clone_input(x.grad, dtype=dtype)
        except RuntimeError:
            # RuntimeError: unsupported operation: more than one element of the written-to
            # tensor refers to a single memory location. Please clone() the tensor before
            # performing the operation.
            return torch_clone(x)
        copy_dynamo_tensor_attributes(x, result)
        return result
@overload
def clone_inputs(
    example_inputs: dict[str, Union[T, tuple[T, ...]]],
) -> dict[str, list[T]]: ...


@overload
def clone_inputs(example_inputs: Sequence[T]) -> list[T]: ...


def clone_inputs(example_inputs: Any) -> Any:
    """
    Clone example inputs with clone_input, preserving the container shape.

    An exact ``dict`` is handled per key:
      - tensor values are cloned;
      - tuple values are recursively cloned into lists;
      - any other value type fails an assertion.

    Any other iterable is materialized as a list in which tensors are cloned
    and all other items are kept as-is.
    """
    if type(example_inputs) is not dict:
        cloned = list(example_inputs)
        for idx, item in enumerate(cloned):
            if isinstance(item, torch.Tensor):
                cloned[idx] = clone_input(item)
        return cloned

    cloned_map: dict[str, Any] = dict(example_inputs)
    for key, value in cloned_map.items():
        if isinstance(value, tuple):
            cloned_map[key] = clone_inputs(value)
        else:
            assert isinstance(value, torch.Tensor), type(value)
            cloned_map[key] = clone_input(value)
    return cloned_map
def skip_frame_if_in_functorch_mode(val: torch.Tensor) -> None:
    """
    Skip the current frame when ``val`` cannot expose a data pointer.

    ``val.data_ptr()`` is probed as the functorch-tensor check; on a
    RuntimeError an ``unimplemented`` graph break with ``skip_frame=True`` is
    reported. Tensors whose ``data_ptr()`` succeeds fall through untouched.
    """
    try:
        val.data_ptr()  # will throw for functorch tensors
    except RuntimeError as probe_error:
        from .exc import unimplemented

        # This will be GradTrackingTensor/BatchedTensor/etc; drop everything
        # from the first "(" of the repr (up to end of that line).
        wrapper_name = re.sub(r"\(.*", "", repr(val))
        unimplemented(
            gb_type="skip frame due to being in functorh mode",
            context="",
            explanation=f"torch.compile cannot be run in context: {wrapper_name}. Skipping frame.",
            hints=[],
            from_exc=probe_error,
            skip_frame=True,
        )
- @contextmanager
- def preserve_rng_state() -> Generator[None, None, None]:
- disable_functorch = torch._C._DisableFuncTorch
- disable_current_modes = torch.utils._python_dispatch._disable_current_modes
- with disable_current_modes(), disable_functorch():
- rng_state = torch.clone(torch.random.get_rng_state())
- skip_frame_if_in_functorch_mode(rng_state)
- if torch.cuda.is_available():
- cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
- if torch.xpu.is_available():
- xpu_rng_state = torch.clone(torch.xpu.get_rng_state())
- try:
- yield
- finally:
- with torch.utils._python_dispatch._disable_current_modes():
- torch.random.set_rng_state(rng_state)
- if torch.cuda.is_available():
- torch.cuda.set_rng_state(cuda_rng_state) # type: ignore[possibly-undefined]
- if torch.xpu.is_available():
- torch.xpu.set_rng_state(xpu_rng_state) # type: ignore[possibly-undefined]
def is_jit_model(
    model0: Any,
) -> TypeIs[
    Union[
        torch.jit._trace.TopLevelTracedModule,
        torch.jit._script.RecursiveScriptModule,
        # pyrefly: ignore [invalid-param-spec]
        torch.jit.ScriptFunction[Any, Any],
        torch.jit.ScriptModule,
    ]
]:
    """Return True if ``model0`` is already a traced/scripted TorchScript module or function."""
    jit_types = (
        torch.jit._trace.TopLevelTracedModule,
        torch.jit._script.RecursiveScriptModule,
        torch.jit.ScriptFunction,
        torch.jit.ScriptModule,
    )
    return isinstance(model0, jit_types)
def torchscript(model: Any, example_inputs: Any, verbose: bool = False) -> Any:
    """
    Convert ``model`` to TorchScript, trying tracing before scripting.

    A model already recognized by is_jit_model is returned unchanged.
    Otherwise ``torch.jit.trace`` is tried with ``example_inputs``, then
    ``torch.jit.script``. If both fail, the failure is logged (with traceback
    when ``verbose``) and None is returned.
    """
    if is_jit_model(model):
        # already done?
        return model

    try:
        return torch.jit.trace(model, example_inputs)
    except Exception:
        try:
            return torch.jit.script(model)
        except Exception:
            if not verbose:
                log.error("Both torch.jit.trace and torch.jit.script failed")
            else:
                log.exception("jit error")
    return None
def getfile(obj: Any) -> Optional[str]:
    """Return the source file of ``obj`` via ``inspect.getfile``.

    Returns None when inspect cannot determine a file (TypeError, e.g. for
    builtins) or raises OSError.
    """
    with contextlib.suppress(TypeError, OSError):
        return inspect.getfile(obj)
    return None
def is_namedtuple(obj: Any) -> bool:
    """Test if an object is an instance of a namedtuple or of a
    (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple."""
    obj_type = type(obj)
    return is_namedtuple_cls(obj_type)
def is_namedtuple_cls(cls: Any) -> bool:
    """Test if a class is a namedtuple or a (torch.return_types|torch.autograd.forward_ad).* quasi-namedtuple.

    Accepted:
    * tuple subclasses defined in ``torch.return_types`` or
      ``torch.autograd.forward_ad``;
    * classes made by ``collections.namedtuple`` / ``typing.NamedTuple``
      (sole base ``tuple``, ignoring a ``typing.Generic`` base);
    * subclasses of such classes that do not customize ``__new__``.
    Non-class inputs (where ``issubclass`` raises TypeError) yield False.
    """
    try:
        if not issubclass(cls, tuple):
            return False
        if getattr(cls, "__module__", None) in (
            "torch.return_types",
            "torch.autograd.forward_ad",
        ):
            return True
        if not isinstance(getattr(cls, "_fields", None), tuple) or not callable(
            getattr(cls, "_make", None)
        ):
            return False
        # The subclassing style namedtuple can have an extra base `typing.Generic`
        real_bases = tuple(base for base in cls.__bases__ if base is not Generic)
        if real_bases == (tuple,):
            # Built directly by `collections.namedtuple(...)`
            return True
        # A subclass of a namedtuple only counts if it keeps the parent's __new__.
        return bool(real_bases) and any(
            is_namedtuple_cls(base) and cls.__new__ is base.__new__
            for base in real_bases
        )
    except TypeError:
        return False
@functools.lru_cache(1)
def namedtuple_fields(cls: type) -> tuple[str, ...]:
    """Get the fields of a namedtuple or a torch.return_types.* quasi-namedtuple.

    ``slice`` is special-cased as ``("start", "stop", "step")``. Classes with a
    ``_fields`` attribute return it directly. Otherwise the class must come from
    ``torch.return_types``: an instance is built from index markers and each
    public attribute holding a marker is mapped back to its position.
    """
    if cls is slice:
        return ("start", "stop", "step")

    assert issubclass(cls, tuple)
    if hasattr(cls, "_fields"):
        # normal namedtuples
        return cls._fields

    # frustrating ones e.g. torch.return_types.max
    assert cls.__module__ == "torch.return_types"

    @dataclasses.dataclass
    class _Slot:
        index: int

    field_count = cls.n_fields  # type: ignore[attr-defined]
    probe = cls(map(_Slot, range(field_count)))
    position_of: dict[str, int] = {}
    for attr in dir(probe):
        if attr[0] == "_":
            continue
        attr_value = getattr(probe, attr)
        if isinstance(attr_value, _Slot):
            position_of[attr] = attr_value.index
    assert len(position_of) == field_count
    return tuple(sorted(position_of, key=position_of.get))  # type: ignore[arg-type]
- def checkpoint_params(gm: torch.fx.GraphModule) -> Callable[[], None]:
- with torch.no_grad():
- rng_state = torch.clone(torch.random.get_rng_state())
- if torch.cuda.is_available():
- cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
- saved_state = [
- (param, param._version, torch.clone(param))
- # pyrefly: ignore [bad-argument-type]
- for param in itertools.chain(gm.parameters(), gm.buffers())
- ]
- def restore() -> None:
- with torch.no_grad():
- torch.random.set_rng_state(rng_state)
- if torch.cuda.is_available():
- torch.cuda.set_rng_state(cuda_rng_state)
- for param, version, original_value in saved_state:
- if param._version != version:
- param.copy_(original_value)
- return restore
def timed(
    model: Any, example_inputs: Iterable[Any], times: int = 1
) -> tuple[Any, float]:
    """Call ``model(*example_inputs)`` ``times`` times and time it.

    CUDA is synchronized (when available) before timing and after every call.
    Garbage is collected and the RNG is seeded with 1337 before timing starts.

    Returns:
        ``(result, seconds)``: the last call's result and the elapsed wall time.

    Raises:
        ValueError: if ``times`` < 1, since there would be no result to return.
    """
    if times < 1:
        raise ValueError(f"timed() requires times >= 1, got {times}")
    # Materialize once so a one-shot iterable is not exhausted after the
    # first call (and so the conversion is not timed).
    args = tuple(example_inputs)
    if torch.cuda.is_available():
        synchronize = torch.cuda.synchronize
    else:
        synchronize = nothing
    synchronize()
    gc.collect()
    torch.manual_seed(1337)
    t0 = time.perf_counter()
    for _ in range(times):
        result = model(*args)
        synchronize()
    t1 = time.perf_counter()
    return result, t1 - t0
def check_is_cuda(gm: torch.fx.GraphModule, example_inputs: Iterable[Any]) -> bool:
    """Return True iff every example input and every (recursive) parameter of
    ``gm`` is a CUDA tensor; vacuously True when there are none."""
    for tensor in itertools.chain(example_inputs, gm.parameters(True)):
        if not tensor.is_cuda:
            return False
    return True
@lru_cache(32)
def rot_n_helper(n: int) -> Callable[..., Any]:
    """Build (and cache) a lambda ``rot_<n>_helper`` over ``n`` positional args.

    With arguments ``v0..v{n-1}`` the lambda returns
    ``(v{n-2}, v{n-3}, ..., v0, v{n-1})``.
    """
    assert n > 1
    names = [f"v{i}" for i in range(n)]
    # Everything but the last argument reversed, then the last argument.
    order = [*names[-2::-1], names[-1]]
    params = ",".join(names)
    body = ",".join(order)
    # eval only sees identifiers generated from an int, never external text.
    fn = eval(f"lambda {params}: ({body})")
    fn.__name__ = f"rot_{n}_helper"
    return fn
common_constant_types: set[type] = {
    # Immutable Python builtins and singleton types.
    int,
    float,
    complex,
    bool,
    str,
    bytes,
    type(None),
    type(Ellipsis),
    type(NotImplemented),
    types.CodeType,
}
# Commonly used immutable types from torch.
common_constant_types.update(
    (
        torch.device,
        torch.dtype,
        torch.memory_format,
        torch.layout,
        torch.finfo,
        torch.iinfo,
        torch.nn.attention.SDPBackend,
        torch.cuda._CudaDeviceProperties,
    )
)
if has_triton_package():
    import triton

    common_constant_types.add(triton.language.dtype)

"""
Difference between is_safe_constant and common_constant_types.
* common_constant_types: types whose values VariableBuilder.wrap_literal wraps
  as ConstantVariable.
* is_safe_constant: values that can be loaded by the LOAD_CONST bytecode.
"""
def is_safe_constant(v: Any) -> bool:
    """Return True if ``v`` may be emitted with LOAD_CONST.

    Exact tuples/frozensets qualify when all their elements do. Otherwise ``v``
    qualifies if it is an enum member, a class, a ``torch.Size``, a generic
    alias, or exactly one of ``common_constant_types`` or ``slice``.
    """
    if istype(v, (tuple, frozenset)):
        return all(is_safe_constant(item) for item in v)
    instance_ok = (
        enum.Enum,
        type,
        torch.Size,
        typing._GenericAlias,  # type: ignore[attr-defined]
        types.GenericAlias,
    )
    if isinstance(v, instance_ok):
        return True
    return istype(v, common_constant_types | {slice})
@functools.cache
def common_constants() -> set[int]:
    """Return the (cached) set of plain ints Dynamo always specializes.

    Shapes are zero-one specialized, so these constants are specialized too.
    """
    zero_and_one: set[int] = {0, 1}
    return zero_and_one
def is_torch_sym(value: Any) -> TypeGuard[Union[torch.SymBool, torch.SymInt]]:
    """Return True for SymBool/SymInt values, excluding nested-int symbols
    (those whose node is a ``NestedIntNode``)."""
    if not isinstance(value, (torch.SymBool, torch.SymInt)):
        return False
    return not isinstance(
        value.node, torch.nested._internal.nested_int.NestedIntNode
    )
def is_int_specialization_case(value: Any, source: Any) -> bool:
    """Decide whether an int coming from ``source`` should be specialized.

    Always False when the tracing context forces unbacked size-like ints.
    Otherwise True for ints from globals, from nn modules (unless
    ``config.allow_unspec_int_on_nn_module``), from defaults, and (behind a
    justknob) for the common constants 0/1.
    """
    from .source import is_from_defaults

    if TracingContext.get().force_unspec_int_unbacked_size_like:
        return False

    # Assume integers from global variables want to be specialized
    if not source.guard_source.is_local():
        return True

    # Assume integers that came from any kind of NN module want to be
    # specialized (users are not expected to change NN modules on the fly),
    # unless explicitly disabled.
    from_nn_module = (
        source.guard_source.is_specialized_nn_module()
        or source.guard_source.is_unspecialized_builtin_nn_module()
        or source.guard_source.is_unspecialized_nn_module()
    )
    if from_nn_module and not config.allow_unspec_int_on_nn_module:
        return True

    if is_from_defaults(source):
        return True

    # TODO: Delete this condition when rollout is done. NB: this
    # condition never evaluates True in open source
    return (
        not justknobs_check("pytorch/dynamo:enable_unspecialize_zero_one_plain_int")
        and value in common_constants()
    )
def specialize_symnode(arg: Any) -> Any:
    """Specialize a symbolic-int variable tracker to a constant, if it is one.

    If ``arg`` is a SymNodeVariable, returns
    ``ConstantVariable.create(arg.evaluate_expr())``. For an unrealized
    LazyVariableTracker whose underlying value would not become a SymNode
    (not a torch sym, and not an unspecialized plain int), ``arg`` is returned
    untouched so it stays lazy. Anything else is returned as-is.
    """
    from .variables import ConstantVariable, LazyVariableTracker, SymNodeVariable

    # Guard and specialize
    if isinstance(arg, LazyVariableTracker) and not arg.is_realized():
        # Find if the arg would be realized as SymNodeVariable later on. If yes,
        # realize it and specialize. Else return the arg.
        # NOTE(review): realization is not explicit here; presumably the
        # isinstance(arg, SymNodeVariable) check below realizes the lazy
        # tracker. Keep the order of these checks — confirm.
        source = arg.original_source()
        value = arg.original_value()

        # Mirrors the decision of whether this value would be wrapped as a
        # SymNodeVariable: torch syms, or plain ints that are not specialized.
        is_symnode_vt = is_torch_sym(value) or (
            not config.specialize_int
            and type(value) is int
            and not is_int_specialization_case(value, source)
        )

        if not is_symnode_vt:
            return arg

    if isinstance(arg, SymNodeVariable):
        # evaluate_expr() presumably installs a guard on the concrete value.
        return ConstantVariable.create(arg.evaluate_expr())
    return arg
def guard_if_dyn(arg: Any) -> Any:
    """Specialize ``arg`` with specialize_symnode, then unwrap it.

    Returns the Python constant if the result is a VariableTracker that is a
    Python constant; otherwise returns the (possibly specialized) argument.
    """
    from .variables import VariableTracker

    specialized = specialize_symnode(arg)
    is_const_vt = isinstance(specialized, VariableTracker) and (
        specialized.is_python_constant()
    )
    return specialized.as_python_constant() if is_const_vt else specialized
def check_constant_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool:
    """Return True iff every positional and keyword argument is a Python constant."""
    for candidate in itertools.chain(args, kwargs.values()):
        if not candidate.is_python_constant():
            return False
    return True
def check_unspec_python_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool:
    """Return True iff every argument is an UnspecializedPythonVariable or a
    constant VariableTracker, and at least one is unspecialized."""
    from .variables import VariableTracker
    from .variables.tensor import UnspecializedPythonVariable

    saw_unspec = False
    for item in itertools.chain(args, kwargs.values()):
        if isinstance(item, UnspecializedPythonVariable):
            saw_unspec = True
            continue
        if isinstance(item, VariableTracker) and item.is_python_constant():
            continue
        return False
    return saw_unspec
def check_unspec_or_constant_args(
    args: Iterable[Any], kwargs: Mapping[Any, Any]
) -> bool:
    """Return True iff every argument is a Python constant or an
    UnspecializedPythonVariable.

    A fused version of:
    ``check_constant_args(args, kwargs) or check_unspec_python_args(args, kwargs)``
    """
    from .variables.tensor import UnspecializedPythonVariable

    return all(
        item.is_python_constant() or isinstance(item, UnspecializedPythonVariable)
        for item in itertools.chain(args, kwargs.values())
    )
def check_numpy_ndarray_args(args: Iterable[Any], kwargs: Mapping[Any, Any]) -> bool:
    """Return True if any positional or keyword argument is a NumpyNdarrayVariable."""
    from .variables.tensor import NumpyNdarrayVariable

    for item in itertools.chain(args, kwargs.values()):
        if isinstance(item, NumpyNdarrayVariable):
            return True
    return False
# Concrete runtime types of builtin views and iterators (not directly importable).
dict_keys: type[KeysView[Any]] = type({}.keys())
dict_values: type[ValuesView[Any]] = type({}.values())
dict_items: type[ItemsView[Any, Any]] = type({}.items())
odict_values: type[ValuesView[Any]] = type(OrderedDict().values())
# pyrefly: ignore [bad-assignment]
tuple_iterator: type[Iterator[Any]] = type(iter(()))
# pyrefly: ignore [bad-assignment]
range_iterator: type[Iterator[Any]] = type(iter(range(0)))
tuple_iterator_len = tuple_iterator.__length_hint__  # type: ignore[attr-defined]

# Unbound builtin constructors / accessors.
object_new = object.__new__
dict_new = dict.__new__
tuple_new = tuple.__new__
list_getitem = list.__getitem__

# Sets of the callable attributes defined directly on builtin container types.
dict_methods = set(
    filter(
        callable,
        itertools.chain(dict.__dict__.values(), OrderedDict.__dict__.values()),
    )
)
set_methods = set(filter(callable, set.__dict__.values()))
frozenset_methods = set(filter(callable, frozenset.__dict__.values()))
tuple_methods = set(filter(callable, tuple.__dict__.values()))
list_methods = set(filter(callable, list.__dict__.values()))
str_methods = set(filter(callable, str.__dict__.values()))
# EnumType is the metaclass for Enum classes
enum_type_methods = set(filter(callable, type(enum.Enum).__dict__.values()))

K = TypeVar("K")
V = TypeVar("V")
def builtin_dict_keys(d: dict[K, V]) -> KeysView[K]:
    """Return ``dict.keys(d)``, ignoring any ``keys`` override on a dict subclass."""
    assert isinstance(d, dict)
    unbound_keys = dict.keys
    return unbound_keys(d)
def get_items_from_dict(obj: dict[K, V]) -> Iterable[tuple[K, Union[V, Any]]]:
    """Get (key, value) pairs without calling a user-defined __getitem__ or keys.

    Exact ``dict``/``OrderedDict`` instances return their own ``items()`` view.
    Subclasses get a list built with the base class's (``OrderedDict`` or
    ``dict``) unbound ``keys`` and ``__getitem__``.
    """
    assert isinstance(obj, dict)
    if istype(obj, (dict, OrderedDict)):
        return obj.items()
    base_cls = OrderedDict if isinstance(obj, OrderedDict) else dict
    # pyrefly: ignore [bad-argument-type]
    return [(key, base_cls.__getitem__(obj, key)) for key in base_cls.keys(obj)]
def nn_module_new(cls: Any) -> Any:
    """Allocate a ``cls`` instance and run only ``torch.nn.Module.__init__`` on
    it, bypassing ``cls.__init__``."""
    instance = object_new(cls)
    # pyrefly: ignore [bad-argument-type]
    torch.nn.Module.__init__(instance)
    return instance
def product(it: Iterable[T]) -> int:
    """Multiply the items of ``it`` left to right, starting from 1 (empty -> 1)."""
    total = 1
    for factor in it:
        total = total * factor
    return total
def tuple_iterator_getitem(it: Any, index: int) -> Any:
    """Peek the element ``index`` positions past a tuple iterator's current
    position, without advancing it (uses the iterator's ``__reduce__``)."""
    _rebuild, (backing,), position = it.__reduce__()
    return backing[position + index]
def dataclass_fields(cls: Any) -> Any:
    """Return ``dataclasses.fields(cls)``, with the call wrapped by
    ``torch._dynamo.disable`` so Dynamo does not trace it."""
    untraced_fields = torch._dynamo.disable(dataclasses.fields)
    return untraced_fields(cls)
# Module-level alias of the builtin ``next``.
# NOTE(review): presumably kept so callers can reference it by a stable name
# from this module — confirm against use sites.
iter_next = next
- def normalize_range_iter(range_iter: Any) -> tuple[int, int, int]:
- _, (range_obj,), maybe_idx = range_iter.__reduce__()
- # In 3.12+, `maybe_idx` could be None, and `range_obj.start` would've been
- # already incremented by the current index.
- # The index (maybe_idx) is the number of steps taken so far. To get the
- # correct start value, one must add (maybe_idx * step) to the original
- # start. See:
- # https://github.com/python/cpython/blob/ea77feecbba389916af8f90b2fc77f07910a2963/Objects/rangeobject.c#L885-L899
- start = range_obj.start + (maybe_idx or 0) * range_obj.step
- stop = range_obj.stop
- step = range_obj.step
- return (start, stop, step)
def to_subclass(t: Any, cls: type) -> Any:
    """Reinterpret tensor ``t`` as an instance of tensor subclass ``cls``
    (via ``Tensor.as_subclass``)."""
    converted = t.as_subclass(cls)
    return converted
# Unbound ``dict.__getitem__``; calling ``dict_getitem(d, k)`` bypasses any
# ``__getitem__`` override defined by a dict subclass.
dict_getitem = dict.__getitem__
@torch.fx.wrap
def dict_keys_getitem(d: dict[Any, Any], n: int) -> Any:
    """Return the ``n``-th key of ``d`` in iteration order.

    The unbound ``OrderedDict.keys`` / ``dict.keys`` is used so overridden
    ``__iter__``/``keys`` on subclasses are not called.
    """
    # pyrefly: ignore [bad-argument-type]
    keys_fn = OrderedDict.keys if isinstance(d, OrderedDict) else dict.keys
    return next(itertools.islice(keys_fn(d), n, n + 1))
def set_getitem(s: set[T], n: int) -> T:
    """Return the ``n``-th element of ``s`` in its current iteration order.

    Set ordering might not be stable across runs.
    """
    as_list = [*s]
    return as_list[n]
def enum_repr(value: Any, local: bool) -> str:
    """Render an enum member as ``L["<Class>"].<name>`` (local) or
    ``G["<Class>"].<name>`` (global).

    An enum class can override ``__str__``, so the class name comes from
    ``__class__.__name__`` and the key from the ``name`` attribute.
    """
    class_name = value.__class__.__name__
    member_name = value.name
    scope = "L" if local else "G"
    return f'{scope}["{class_name}"].{member_name}'
def set_example_value(node: torch.fx.Node, example_value: Any) -> None:
    """Store ``example_value`` on ``node.meta`` and record unbacked bindings.

    Bindings computed by ``compute_unbacked_bindings`` against the current
    tracing fake mode's shape env are stored under
    ``node.meta["unbacked_bindings"]`` when non-empty.
    """
    # NB: example_value is a bit of a misnomer, because this is always a fake
    # tensor of some sort. These example values are also the runtime state of
    # Dynamo tracing: metadata mutation updates the example_value in place, so
    # it need not reflect the value's state at the time it was traced.
    node.meta["example_value"] = example_value
    fake_mode = TracingContext.get().fake_mode
    assert fake_mode is not None
    bindings = torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
        fake_mode.shape_env, example_value
    )
    if bindings:
        node.meta["unbacked_bindings"] = bindings
def _get_fake_tensor(vt: VariableTracker) -> Any:
    """Return the fake ``example_value`` on ``vt``'s proxy node.

    Triggers a graph break (via ``unimplemented``) when the value is missing
    or not a fake tensor.
    """
    example = vt.as_proxy().node.meta.get("example_value")
    if is_fake(example):
        return example
    from . import graph_break_hints
    from .exc import unimplemented

    unimplemented(
        gb_type="Cannot check Tensor object identity without its fake value",
        context=str(example),
        explanation="TensorVariable is missing a fake example_value.",
        hints=[*graph_break_hints.DYNAMO_BUG],
    )
    return example
def slice_length(s: slice, seq_len: int) -> int:
    """Return how many elements slice ``s`` selects from a sequence of length
    ``seq_len`` (never negative)."""
    return len(range(*s.indices(seq_len)))
def raise_args_mismatch(
    tx: InstructionTranslatorBase,
    name: str,
    expect: str = "",
    actual: str = "",
) -> None:
    """Raise an observed ``TypeError`` inside ``tx`` for a bad call arity.

    Args:
        tx: the instruction translator the exception is raised in.
        name: callee name shown as ``name()`` in the message.
        expect: description of the expected arguments.
        actual: description of the arguments actually passed.
    """
    from torch._dynamo.exc import raise_observed_exception
    from torch._dynamo.variables import ConstantVariable

    message_lines = [
        f"wrong number of arguments or keyword arguments for {name}() call.",
        f" Expect: {expect}",
        f" Actual: {actual}",
    ]
    raise_observed_exception(
        TypeError,
        tx,
        args=[ConstantVariable("\n".join(message_lines))],
    )
def iter_contains(
    items: Iterable[Any],
    search: Any,
    tx: InstructionTranslator,
    check_tensor_identity: bool = False,
) -> Any:
    """Model ``search in items`` over variable trackers, returning a tracker.

    Args:
        items: variable trackers to search through.
        search: the variable tracker being looked for.
        tx: instruction translator used to build the ``==`` / ``or`` calls.
        check_tensor_identity: if True and ``search`` is a tensor, match by
            fake-tensor object identity instead of ``==``.

    Returns:
        A ConstantVariable(bool) when the answer is decided statically;
        otherwise the tracker from or-ing the ``x == search`` checks.
    """
    from .variables import ConstantVariable

    if search.is_python_constant():
        # NOTE(review): non-constant items are treated as non-matching here.
        found_const = any(
            x.is_python_constant()
            and x.as_python_constant() == search.as_python_constant()
            for x in items
        )
        return ConstantVariable.create(found_const)

    must_check_tensor_id = False
    if check_tensor_identity and search.is_tensor():
        must_check_tensor_id = True
        # Match of Tensor means match of FakeTensor
        search = _get_fake_tensor(search)

    found: Optional[VariableTracker] = None
    for x in items:
        if must_check_tensor_id:
            # Identity mode: only tensor items can match; others are skipped.
            if x.is_tensor():
                if search is _get_fake_tensor(x):  # Object equivalence
                    return ConstantVariable.create(True)
        else:
            from torch._dynamo.variables.builder import SourcelessBuilder

            check = SourcelessBuilder.create(tx, operator.eq).call_function(
                tx, [x, search], {}
            )
            # Accumulate `check or found` across all items.
            if found is None:
                found = check
            else:
                found = SourcelessBuilder.create(tx, operator.or_).call_function(
                    tx, [check, found], {}
                )
    if found is None:
        # No item matched by identity, or `items` was empty.
        found = ConstantVariable.create(False)
    return found
def key_is_id(
    k: Any,
) -> TypeIs[Union[torch.Tensor, torch.nn.Module, MethodWrapperType]]:
    """Returns whether it indexes dictionaries using its id.

    True for tensors, nn modules and method-wrapper objects.
    """
    id_keyed_types = (torch.Tensor, torch.nn.Module, MethodWrapperType)
    return isinstance(k, id_keyed_types)
def key_to_id(value: Any) -> list[Any]:
    """Map each key in ``value`` to ``id(key)`` when key_is_id(key), else keep it."""
    converted: list[Any] = []
    for key in value:
        converted.append(id(key) if key_is_id(key) else key)
    return converted
def const_repr(x: Any, *, local: Any) -> str:
    """Render constant ``x`` as source text for guard code.

    Lists and tuples recurse (1-tuples keep their trailing comma); enum
    members go through ``enum_repr`` with quotes stripped; builtin callables
    render as their ``__name__``; everything else falls back to ``repr``.
    """
    from .trace_rules import is_builtin_callable

    if isinstance(x, (list, tuple)):
        inner = ",".join(const_repr(item, local=local) for item in x)
        if isinstance(x, list):
            return f"[{inner}]"
        assert isinstance(x, tuple)
        return f"({inner},)" if len(x) == 1 else f"({inner})"
    if isinstance(x, enum.Enum):
        # To workaround repr(Enum) returning invalid global reference before python 3.11
        # by calling enum_repr and removing quotes to render enum in guard code.
        return enum_repr(x, local=local).replace("'", "")
    if is_builtin_callable(x):
        return x.__name__
    if isinstance(x, type):
        # NOTE(review): this renders the class of ``x`` — its metaclass (e.g.
        # ``type``) — rather than ``x`` itself. Preserved as-is; confirm intent.
        meta = x.__class__
        meta_module = meta.__module__
        if meta_module == "builtins":
            return meta.__qualname__  # avoid outputs like 'builtins.str'
        return meta_module + "." + meta.__qualname__
    return f"{x!r}"
def dict_keys_repr(const_keys: Any, *, local: Any) -> str:
    """Render constant dict keys as a list literal, each via ``const_repr``."""
    rendered = [const_repr(key, local=local) for key in const_keys]
    return "[" + ",".join(rendered) + "]"
# Prefix string for global names associated with dict keys.
# NOTE(review): its use sites are outside this chunk — confirm naming scheme.
GLOBAL_KEY_PREFIX = "__dict_key"
# Mid-module import; UnsupportedFakeTensorException is caught by
# wrap_fake_exception in this module. The noqa presumably also marks it as
# intentionally re-exported — confirm.
from torch._subclasses import UnsupportedFakeTensorException  # noqa: F401
def get_safe_global_name(tx: InstructionTranslatorBase, root: str, obj: Any) -> str:
    """Build a global name ``<root>_<id(obj)>_c<compile_id>``.

    The name must differ across torch.compile invocations: otherwise several
    invocations could reuse one global whose lifetime is tied to the first
    invocation (and may be deleted with it). Including the output graph's
    compile id keeps the names distinct.
    """
    compile_id = tx.output.compile_id
    return "{}_{}_c{}".format(root, id(obj), compile_id)
def is_in(item: T, *containers: Container[T]) -> bool:
    """Return True if ``item`` is in at least one of ``containers``."""
    return any(item in container for container in containers)
def get_unique_name_wrt(
    prefix: str, *containers: Any, requires_suffix: bool = False
) -> str:
    """
    Return a name that starts with `prefix` and is not in any of the
    `containers` (e.g., map, set).

    ``prefix`` itself is returned when free and ``requires_suffix`` is False;
    otherwise the first free ``f"{prefix}_{i}"`` for i = 0, 1, 2, ...
    """
    if not requires_suffix and not is_in(prefix, *containers):
        return prefix
    suffix = 0
    while True:
        candidate = f"{prefix}_{suffix}"
        if not is_in(candidate, *containers):
            return candidate
        suffix += 1
def wrap_fake_exception(fn: Callable[[], Any]) -> Any:
    """Call ``fn()`` and return its result.

    An ``UnsupportedFakeTensorException`` is logged as a warning and turned
    into a graph break via ``unimplemented``; other exceptions propagate.
    """
    try:
        return fn()
    except UnsupportedFakeTensorException as exc:
        from .exc import unimplemented

        message = f"Encountered exception ({exc.reason}) during fake tensor propagation."
        log.warning(message)
        unimplemented(
            gb_type="Fake tensor propagation exception",
            context=str(exc.reason),
            explanation=message,
            hints=[],
            from_exc=exc,
        )
def deepcopy_to_fake_tensor(
    obj: Any, fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
) -> Any:
    """Deep-copy ``obj`` under ``FakeCopyMode(fake_mode)``, routing fake-tensor
    failures through ``wrap_fake_exception``."""

    def _copy() -> Any:
        return copy.deepcopy(obj)

    with torch._subclasses.fake_tensor.FakeCopyMode(fake_mode):
        return wrap_fake_exception(_copy)
def rmse(ref: torch.Tensor, res: torch.Tensor) -> torch.Tensor:
    """
    Calculate root mean squared error: ``sqrt(mean((ref - res) ** 2))``.
    """
    diff = ref - res
    return diff.square().mean().sqrt()
def bitwise_same(ref: Any, res: Any, equal_nan: bool = False) -> bool:
    """Compare ``ref`` and ``res`` with ``same`` at zero tolerance."""
    zero_tol = 0.0
    return same(ref, res, tol=zero_tol, equal_nan=equal_nan)
def same(
    ref: Any,
    res: Any,
    fp64_ref: Any = None,
    cos_similarity: bool = False,
    tol: float = 1e-4,
    equal_nan: bool = False,
    exact_dtype: bool = True,
    relax_numpy_equality: bool = False,
    ignore_non_fp: bool = False,
    log_error: Callable[..., None] = log.error,
    use_larger_multiplier_for_smaller_tensor: bool = False,
    force_max_multiplier: bool = False,
    use_iou_for_bool: bool = False,
    iou_threshold: float = 0.99,
) -> bool:
    """Check correctness to see if ref and res match.

    Recursively compares ``res`` against ``ref``. ``fp64_ref`` is an optional
    reference with the same structure (defaults to ``ref``); where its tensors
    are float64, an RMSE comparison against it is used when plain
    ``allclose`` fails.

    Supported ``ref`` types: list/tuple/deque/ParameterList/torch.Size, dict,
    set, Tensor/float, str/int/None/bool/torch.device, numpy scalars, numpy
    arrays, and a fixed set of output classes matched by class name. Any other
    type raises RuntimeError. Every option is forwarded unchanged to nested
    comparisons. Mismatches are reported through ``log_error``.
    """
    if fp64_ref is None:
        fp64_ref = ref

    # Single source of truth for recursive calls, so that every container and
    # wrapper branch honors all of the caller's options.
    nested_kwargs: dict[str, Any] = {
        "cos_similarity": cos_similarity,
        "tol": tol,
        "equal_nan": equal_nan,
        "exact_dtype": exact_dtype,
        "relax_numpy_equality": relax_numpy_equality,
        "ignore_non_fp": ignore_non_fp,
        "log_error": log_error,
        "use_larger_multiplier_for_smaller_tensor": use_larger_multiplier_for_smaller_tensor,
        "force_max_multiplier": force_max_multiplier,
        "use_iou_for_bool": use_iou_for_bool,
        "iou_threshold": iou_threshold,
    }

    if isinstance(
        ref, (list, tuple, collections.deque, torch.nn.ParameterList, torch.Size)
    ):
        assert isinstance(res, (list, tuple, collections.deque)), (
            f"type mismatch {type(ref)} {type(res)}"
        )
        if len(ref) != len(res):
            log_error("Length mismatch")
            return False
        return all(
            same(ai, bi, fp64_refi, **nested_kwargs)
            for ai, bi, fp64_refi in zip(ref, res, fp64_ref)
        )
    elif type(ref).__name__ == "QuestionAnsweringModelOutput":
        # This skips checking accuracy for start_logits/end_logits.
        # Tentatively, start_logits/end_logits appear to be very prone to
        # inaccuracies and is somewhat subsumed by checking the loss.
        return same(ref.loss, res.loss, fp64_ref.loss, **nested_kwargs)
    elif isinstance(ref, dict):
        assert isinstance(res, dict)
        assert set(ref.keys()) == set(res.keys()), (
            f"keys mismatch {set(ref.keys())} == {set(res.keys())}"
        )
        for k in sorted(ref.keys()):
            if not same(ref[k], res[k], fp64_ref[k], **nested_kwargs):
                log_error("Accuracy failed for key name %s", k)
                return False
        return True
    elif isinstance(ref, set):
        assert isinstance(res, set)
        assert set(ref) == set(res), f"elements mismatch {set(ref)} == {set(res)}"
        return True
    elif isinstance(ref, (torch.Tensor, float)):
        assert not isinstance(ref, torch._subclasses.FakeTensor)
        assert not isinstance(res, torch._subclasses.FakeTensor)

        def to_tensor(t: Any) -> torch.Tensor:
            return t if isinstance(t, torch.Tensor) else torch.tensor(t)

        ref, res, fp64_ref = (to_tensor(val) for val in (ref, res, fp64_ref))

        if ref.is_sparse:
            assert res.is_sparse
            ref = ref.to_dense()
            res = res.to_dense()
        assert isinstance(res, torch.Tensor), f"type mismatch {type(ref)} {type(res)}"
        if exact_dtype:
            if ref.dtype != res.dtype:
                log_error("dtype mismatch %s, %s", ref.dtype, res.dtype)
                return False
            if ref.dtype == torch.bool:
                if ignore_non_fp:
                    return True
                if use_iou_for_bool:
                    # Use IoU (Intersection over Union) metric for boolean mask comparison.
                    # This is useful for segmentation models where small floating-point
                    # differences get thresholded into boolean masks.
                    intersection = (ref & res).sum().float()
                    union = (ref | res).sum().float()
                    if union == 0:
                        # Both masks are empty
                        return bool(intersection == 0)
                    iou = (intersection / union).item()
                    if iou < iou_threshold:
                        log_error(
                            "IoU accuracy failed: %.4f < %.2f (intersection=%d, union=%d, ref_sum=%d, res_sum=%d, shape=%s)",
                            iou,
                            iou_threshold,
                            int(intersection.item()),
                            int(union.item()),
                            int(ref.sum().item()),
                            int(res.sum().item()),
                            list(ref.shape),
                        )
                        return False
                    return True
                # triton stores bool as int8, so add this for more accurate checking
                r = torch.allclose(
                    ref.to(dtype=torch.uint8),
                    res.to(dtype=torch.uint8),
                    atol=tol,
                    rtol=tol,
                    equal_nan=equal_nan,
                )
                if not r:
                    log_error("Accuracy failed: uint8 tensor did not match")
                return r

        if cos_similarity:
            ref = ref.flatten().to(torch.float32)
            res = res.flatten().to(torch.float32)
            if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=True):
                # early exit that handles zero/nan better
                # cosine_similarity(zeros(10), zeros(10), dim=0) is 0
                return True
            score = torch.nn.functional.cosine_similarity(ref, res, dim=0, eps=1e-6)
            if score < 0.99:
                log.warning("Similarity score=%s", score.detach().cpu().item())
            return bool(score >= 0.99)
        else:
            if not exact_dtype:
                ref = ref.to(res.dtype)

            # First try usual allclose
            if torch.allclose(ref, res, atol=tol, rtol=tol, equal_nan=equal_nan):
                return True

            # Check error from fp64 version
            if fp64_ref.dtype == torch.float64:
                # Fix a corner case that res and fp64_ref does not contains NaN and match (with loose tolerance)
                # while the ref contains NaN. In this case, RMSE should not match any ways.
                # But res is 'BETTER' than ref so we count it pass.
                #
                # This happens for Super_SloMo when loop ordering after fusion is enabled:
                # https://gist.github.com/shunting314/11f235c70f7db0d52718d26f4a701cab
                loose_tol = 1e-2 * 4
                if (
                    not fp64_ref.isnan().any()
                    and not res.isnan().any()
                    and ref.isnan().any()
                    and torch.allclose(
                        fp64_ref.to(dtype=res.dtype),
                        res,
                        atol=loose_tol,
                        rtol=loose_tol,
                        equal_nan=equal_nan,
                    )
                ):
                    return True
                ref_error = rmse(fp64_ref, ref).item()
                # ref unable to produce this with stable numerics in this precision, ignore
                if math.isnan(ref_error):
                    log.warning(
                        "Found nan in reference. Consider running in higher precision."
                    )

                res_error = rmse(fp64_ref, res).item()

                def get_multiplier() -> float:
                    # In some particular cases, we expect high difference in results.
                    # At the moment one of this cases is inductor freezing bfloat16 convolution const folding.
                    # In case of it the res_error is at least one order of magnitude higher.
                    if force_max_multiplier:
                        return 10.0
                    # In the case of using AMP (Automatic Mixed Precision), certain models have
                    # failed the benchmark's correctness check. However, the end-to-end model's
                    # accuracy when comparing AMP with FP32 is within a difference of less than 0.1%.
                    # Thus, it's possible that the correctness check failures for these models are
                    # false alarms. We use multiplier of 3 instead of 2 to avoid these false alarms.
                    multiplier = (
                        3.0 if res.dtype in (torch.float16, torch.bfloat16) else 2.0
                    )

                    if use_larger_multiplier_for_smaller_tensor and (
                        fp64_ref.numel() <= 10
                    ):
                        multiplier = 10.0
                    elif use_larger_multiplier_for_smaller_tensor and (
                        fp64_ref.numel() <= 500
                    ):
                        multiplier = 8.0
                    elif (
                        fp64_ref.numel() < 1000
                        or (ref.ndim == 4 and ref.shape[-1] == ref.shape[-2] == 1)
                        # large tol means a benchmark has been specified as REQUIRE_HIGHER_TOLERANCE
                        or tol >= 2 * 1e-2
                    ):
                        # In the presence of noise, noise might dominate our error
                        # metric for smaller tensors.
                        # Similarly, for 1x1 kernels, there seems to be high noise with amp.
                        multiplier = 3.0
                    return multiplier

                multiplier = get_multiplier()

                passes_test = res_error <= (multiplier * ref_error + tol / 10.0)
                if (
                    not passes_test
                    and equal_nan
                    and math.isnan(ref_error)
                    and math.isnan(res_error)
                    # Some unit test for the accuracy minifier relies on
                    # returning false in this case.
                    and not torch._inductor.config.cpp.inject_relu_bug_TESTING_ONLY
                ):
                    passes_test = True
                if not passes_test:
                    log_error(
                        "RMSE (res-fp64): %.5f, (ref-fp64): %.5f and shape=%s. res.dtype: %s, multiplier: %f, tol: %f"
                        ", use_larger_multiplier_for_smaller_tensor: %d",
                        res_error,
                        ref_error,
                        res.size(),
                        res.dtype,
                        multiplier,
                        tol,
                        use_larger_multiplier_for_smaller_tensor,
                    )
                return passes_test

            if ignore_non_fp:
                return True

            log_error("Accuracy failed: allclose not within tol=%s", tol)
            return False
    elif isinstance(ref, (str, int, type(None), bool, torch.device)):
        if ignore_non_fp:
            return True
        r = ref == res
        if not r:
            log_error("Accuracy failed (%s): %s != %s", type(ref), ref, res)
        return r
    elif is_numpy_int_type(ref) or is_numpy_float_type(ref):
        if relax_numpy_equality and not (
            is_numpy_int_type(res) or is_numpy_float_type(res)
        ):
            ref = ref.item()
        r = (type(ref) is type(res)) and (ref == res)
        if not r:
            log_error("Accuracy failed (numpy): %s != %s", ref, res)
        return r
    elif is_numpy_ndarray(ref):
        # Previously this branch dropped force_max_multiplier, use_iou_for_bool
        # and iou_threshold; they are now forwarded like everywhere else.
        return (type(ref) is type(res)) and same(
            torch.as_tensor(ref),
            torch.as_tensor(res),
            fp64_ref,
            **nested_kwargs,
        )
    elif type(ref).__name__ in (
        "MaskedLMOutput",
        "Seq2SeqLMOutput",
        "CausalLMOutputWithCrossAttentions",
        "LongformerMaskedLMOutput",
        "Instances",
        "SquashedNormal",
        "Boxes",
        "Normal",
        "TanhTransform",
        "Foo",
        "Variable",
    ):
        assert type(ref) is type(res)
        return all(
            same(
                getattr(ref, key),
                getattr(res, key),
                getattr(fp64_ref, key),
                **nested_kwargs,
            )
            for key in ref.__dict__
        )
    else:
        raise RuntimeError(f"unsupported type: {type(ref).__name__}")
def format_func_info(code: CodeType) -> str:
    """Describe a code object as ``'<name>' (<file basename>:<first line>)``.

    The basename is whatever follows the last ``/`` in ``co_filename``.
    """
    _, _, basename = code.co_filename.rpartition("/")
    return "'{}' ({}:{})".format(code.co_name, basename, code.co_firstlineno)
@contextlib.contextmanager
def disable_cache_limit() -> Generator[None, None, None]:
    """Temporarily set ``config.recompile_limit`` and
    ``config.accumulated_recompile_limit`` to ``sys.maxsize``.

    Both previous values are restored on exit, even if the body raises.
    """
    saved_limits = (config.recompile_limit, config.accumulated_recompile_limit)
    # pyrefly: ignore [bad-assignment]
    config.recompile_limit = sys.maxsize
    # pyrefly: ignore [bad-assignment]
    config.accumulated_recompile_limit = sys.maxsize
    try:
        yield
    finally:
        config.recompile_limit, config.accumulated_recompile_limit = saved_limits
# Weakly keyed maps over code objects:
# - orig_code_map: transformed code -> original user code.
# - seen_code_map: code Dynamo has already compiled, tracked when running in
#   "error if recompile" mode.
orig_code_map = ExactWeakKeyDictionary()
seen_code_map = ExactWeakKeyDictionary()

# Logging aids: code_obj -> list of guard failure reasons, and the list of
# graph break reasons.
guard_failures: collections.defaultdict[Any, list[Any]] = collections.defaultdict(list)
graph_break_reasons: list[torch._dynamo.output_graph.GraphCompileReason] = []
- # return same dir unless user changes config between calls
- @functools.cache
- def _get_debug_dir(root_dir: str) -> str:
- dir_name = (
- "run_"
- + datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
- # use pid to avoid conflicts among ranks
- + "-pid_"
- + str(os.getpid())
- )
- return os.path.join(root_dir, dir_name)
def get_debug_dir() -> str:
    """Return the (cached) debug directory under ``config.debug_dir_root``."""
    return _get_debug_dir(config.debug_dir_root)
def extract_fake_example_value(node: torch.fx.Node, required: bool = True) -> Any:
    """Return ``node.meta["example_value"]`` if present and fake.

    If it is missing or not fake: graph-break via ``unimplemented`` when
    ``required``, otherwise return None.
    """
    if "example_value" in node.meta:
        candidate = node.meta["example_value"]
        if is_fake(candidate):
            return candidate
    if not required:
        return None
    from torch._dynamo.exc import unimplemented

    from . import graph_break_hints

    unimplemented(
        gb_type="Missing FakeTensor example value",
        context=str(node),
        explanation=f"`FakeTensor` example value was required for {node} but not available.",
        hints=[*graph_break_hints.DYNAMO_BUG],
    )
def ensure_graph_fake(e: Any, tx: InstructionTranslatorBase) -> Any:
    """Assert that ``e``'s fake mode is ``tx.fake_mode``, then return ``e``."""
    mode_of_e = maybe_get_fake_mode(e)
    assert mode_of_e is tx.fake_mode
    return e
def get_fake_values_from_nodes(
    tx: InstructionTranslatorBase, nodes: Any, allow_non_graph_fake: bool
) -> Any:
    """Replace each fx node in ``nodes`` with its fake value via ``map_arg``.

    Nodes without ``example_value``: ``call_function`` nodes are run through
    ``get_fake_value``, and ``get_attr`` nodes resolve to their GraphModule in
    ``tx.output.nn_modules``. Otherwise the stored ``example_value`` is used;
    tensors are checked with ``ensure_graph_fake`` unless
    ``allow_non_graph_fake``.
    """

    def visit(n: torch.fx.Node) -> Any:
        if "example_value" not in n.meta:
            if n.op == "call_function":
                # fake tensor validity is checked inside get_fake_value using
                # ensure_graph_fake
                return get_fake_value(n, tx, allow_non_graph_fake)
            if n.op == "get_attr":
                assert n.target in tx.output.nn_modules
                gm = tx.output.nn_modules[n.target]  # type: ignore[index]
                assert isinstance(gm, torch.fx.GraphModule)
                return gm
        stored = n.meta["example_value"]
        if not allow_non_graph_fake and isinstance(stored, torch.Tensor):
            return ensure_graph_fake(stored, tx)
        return stored

    return torch.fx.node.map_arg(nodes, visit)
def get_concrete_sizes_from_symints(
    msg: str, fake_mode: Optional[FakeTensorMode]
) -> str:
    """
    Replace symbolic size expressions (like 's0', 's94') in error messages
    with their concrete runtime values for better readability.

    Only parenthesized symbols such as ``(s94)`` are rewritten, and only when
    the shape env holds an integer hint for them; everything else is kept.

    Example: "size (s94)" -> "size (s94: hint = 10)" if s94's value is 10.

    Args:
        msg: the error message to rewrite.
        fake_mode: fake mode whose shape env supplies the hints. If it is
            ``None`` or has no shape env, ``msg`` is returned unchanged.

    Returns:
        ``msg`` with every resolvable ``(sN)`` replaced by ``(sN: hint = V)``.
    """
    import re

    from sympy.core.numbers import Integer

    # This helper runs while an error message is being built. Failing here
    # (e.g. via an assert on a missing shape env) would hide the real error,
    # so degrade to returning the message untouched.
    if fake_mode is None or fake_mode.shape_env is None:
        return msg

    pattern = re.compile(r"\(s(\d+)\)")
    if pattern.search(msg) is None:
        return msg

    # Index hints by symbol name once rather than scanning the whole mapping
    # per match. setdefault keeps the first entry per name, matching a
    # first-match linear search.
    hints_by_name: dict[str, Any] = {}
    for sym, val in fake_mode.shape_env.backed_var_to_val.items():
        hints_by_name.setdefault(sym.name, val)

    def replace_sym(match: Any) -> str:
        sym_name = f"s{match.group(1)}"
        val = hints_by_name.get(sym_name)
        if isinstance(val, (int, Integer)):
            return f"({sym_name}: hint = {str(val)})"
        return match.group(0)

    return pattern.sub(replace_sym, msg)
def _wrap_graph_break_with_torch_runtime_err(gb_fn: Callable[[], NoReturn]) -> NoReturn:
    """Run ``gb_fn`` and re-raise its graph break as a ``TorchRuntimeError``.

    ``gb_fn`` is expected to raise ``Unsupported``. That exception is turned
    into a ``TorchRuntimeError`` with the same message, ``real_stack`` (if
    any) and traceback, with exception chaining suppressed. Other exceptions
    propagate unchanged. If ``gb_fn`` returns normally, an ``AssertionError``
    is raised.
    """
    from .exc import TorchRuntimeError, Unsupported

    try:
        gb_fn()
    except Unsupported as unsupported:
        wrapped = TorchRuntimeError(
            str(unsupported), getattr(unsupported, "real_stack", None)
        )
        raise wrapped.with_traceback(unsupported.__traceback__) from None
    else:
        raise AssertionError("should be unreachable")
def get_fake_value(
    node: torch.fx.Node,
    tx: InstructionTranslatorBase,
    allow_non_graph_fake: bool = False,
) -> Any:
    """Timed entry point for ``_get_fake_value_impl``.

    The elapsed wall time in nanoseconds is added to
    ``tx.output.bytecode_tracing_timings.get_fake_value_ns``, including when
    the computation raises.
    """
    started_ns = time.time_ns()
    try:
        return _get_fake_value_impl(node, tx, allow_non_graph_fake)
    finally:
        elapsed_ns = time.time_ns() - started_ns
        tx.output.bytecode_tracing_timings.get_fake_value_ns += elapsed_ns
def _get_fake_value_impl(
    node: torch.fx.Node,
    tx: InstructionTranslatorBase,
    allow_non_graph_fake: bool = False,
) -> Any:
    """
    Run the computation represented by `node` using fake tensors and return the result.

    allow_non_graph_fake: whether to allow the return result to be:
        1. non-fake or 2. fake that is not created by this instance of Dynamo.
        If `True`, you must be prepared to deal with such return values, ideally
        by further wrapping them as this graph's fakes.

    If `node.meta` already holds a fake `example_value`, it is returned as-is.
    Otherwise the node's inputs are resolved to fake values and the node is run
    under `tx.fake_mode` with the Python dispatcher enabled.

    Raises:
        Unsupported: (through `unimplemented`) for data-dependent outputs,
            dynamic output shapes, operators without fake support and
            argument TypeErrors raised during the fake call.
        UserError: for data-dependent guards and value-range violations.
        TorchRuntimeError: for any other RuntimeError raised while running
            the node (produced by `_wrap_graph_break_with_torch_runtime_err`).
    """
    from torch.utils._sympy.value_ranges import ValueRangeError

    from . import graph_break_hints
    from .exc import unimplemented, Unsupported, UserError, UserErrorType

    op = node.op

    # FX Node should always return the same fake value
    if "example_value" in node.meta and is_fake(node.meta["example_value"]):
        return node.meta["example_value"]

    args, kwargs = get_fake_values_from_nodes(
        tx, (node.args, node.kwargs), allow_non_graph_fake
    )

    if (
        torch._dynamo.config.use_graph_deduplication
        or torch._dynamo.config.track_nodes_for_deduplication
    ):
        # Snapshot the `_version` counter of every fake input before running
        # the node; these are handed to `track_node_mutations` below
        # (presumably so in-place mutation of inputs can be detected — confirm).
        flat_args_kwargs = get_fake_values_from_nodes(
            tx, _get_flat_args(node, {}), allow_non_graph_fake
        )
        id_to_initial_version = {
            id(arg): arg._version for arg in flat_args_kwargs if is_fake(arg)
        }
    else:
        # pyrefly: ignore [implicit-any]
        flat_args_kwargs = []
        # pyrefly: ignore [implicit-any]
        id_to_initial_version = {}

    nnmodule = None
    fake_mode = tx.fake_mode
    assert fake_mode is not None
    if op == "call_method" and len(args) > 0 and isinstance(args[0], torch.nn.Module):
        # If the first argument is nn.Module, should copy to fake mode.
        args = (deepcopy_to_fake_tensor(args[0], fake_mode),) + tuple(args[1:])

    if op == "call_module":
        nnmodule = tx.output.nn_modules[node.target]  # type: ignore[index]

        if is_lazy_module(nnmodule) and hasattr(nnmodule, "_initialize_hook"):
            # In the case of a lazy module, we want to run
            # the pre-hooks which initialize it.
            # Afterwards, lazy module deletes its pre-hooks
            # to avoid treating it as lazy on subsequent recompile.
            nnmodule._infer_parameters(nnmodule, args)

        # no matter it's lazy module or not, we should copy to fake mode.
        nnmodule = deepcopy_to_fake_tensor(nnmodule, fake_mode)

    if node.name in ["interpolate", "is_integer", "wrapped_gradient"] or any(
        isinstance(a, complex) for a in args
    ):
        # We need to specialize symfloats for now. Eventually we should do a tensorify pass in dynamo.
        # Only SymFloats that carry a concrete hint are converted to Python floats.
        args = tuple(
            (
                float(arg)
                if isinstance(arg, torch.SymFloat) and arg.node.hint is not None
                else arg
            )
            for arg in args
        )

    try:
        with fake_mode, enable_python_dispatcher():
            ret_val = wrap_fake_exception(
                lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            )
    except Unsupported:
        # Graph breaks raised while running the node are already well-formed.
        raise
    except RuntimeError as e:
        # run_node re-raises failures as RuntimeError chained to the original
        # exception, so classify by the underlying cause when there is one.
        cause: BaseException = e
        if e.__cause__ is not None:
            cause = e.__cause__

        if isinstance(
            cause, torch._subclasses.fake_tensor.DataDependentOutputException
        ):
            # capture_scalar_outputs only works for these ops right now
            # see torch/_subclasses/fake_impls.py
            if cause.func in (
                torch.ops.aten.item.default,
                torch.ops.aten._local_scalar_dense.default,
            ):
                # does this actually get triggered?
                hints = [
                    "Enable tracing of data-dependent output operators with "
                    "`torch._dynamo.config.capture_scalar_outputs = True`",
                ]
            else:
                hints = [
                    "Consider wrapping the operator into a PyTorch-understood custom operator "
                    "(see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)",
                ]
            unimplemented(
                gb_type="Data dependent operator",
                context=str(cause.func),
                explanation=f"Operator `{cause.func}` has a non-Tensor output "
                "whose value is dependent on the data of Tensor inputs.",
                hints=hints,
                from_exc=cause,
            )
        elif isinstance(
            cause, torch._subclasses.fake_tensor.DynamicOutputShapeException
        ):
            if not torch._dynamo.config.capture_dynamic_output_shape_ops:
                unimplemented(
                    gb_type="Dynamic shape operator",
                    context=str(cause.func),
                    explanation=f"Operator `{cause.func}`'s output shape depends on input Tensor data.",
                    hints=[
                        "Enable tracing of dynamic shape operators with "
                        "`torch._dynamo.config.capture_dynamic_output_shape_ops = True`",
                    ],
                    from_exc=cause,
                )
            else:
                # Capture is enabled, so the op itself lacks a meta kernel
                # that can produce dynamic output shapes.
                unimplemented(
                    gb_type="Dynamic shape operator (no meta kernel)",
                    context=str(cause.func),
                    explanation=f"Operator `{cause.func}` does not have a meta kernel that supports dynamic output shapes",
                    hints=[
                        "Please report an issue to PyTorch",
                    ],
                    from_exc=cause,
                )
        elif isinstance(
            cause, torch._subclasses.fake_tensor.UnsupportedOperatorException
        ):
            # NB: rebinds the outer `op` (the node's opcode) to the operator
            # that failed; `op` is not read again on this path.
            op = cause.func  # type: ignore[assignment]
            import_suggestion = ""
            if isinstance(op, torch._ops.OpOverload):
                # If a Python stub is registered for this op, suggest importing
                # the module that provides its implementation.
                maybe_pystub = torch._C._dispatch_pystub(
                    op._schema.name, op._schema.overload_name
                )
                if maybe_pystub is not None:
                    module, ctx = maybe_pystub
                    import_suggestion = (
                        f"It's possible that the support was implemented in "
                        f"module `{module}` and you may need to `import {module}`"
                        f"({ctx}), otherwise "
                    )
            unimplemented(
                gb_type="Operator does not support running with fake tensors",
                context=f"unsupported operator: {cause.func}",
                explanation="",
                hints=[
                    f"{import_suggestion}see "
                    "https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0"
                    " for how to fix",
                ],
                from_exc=cause,
            )
        elif isinstance(
            cause, torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode
        ):
            raise UserError(  # noqa: B904
                UserErrorType.CONSTRAINT_VIOLATION,
                str(cause),
                case_name="constrain_as_size_example",
            )
        elif isinstance(cause, ValueRangeError):
            # NB: the message comes from the outer RuntimeError `e`, not `cause`.
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, e.args[0]) from e
        elif isinstance(cause, TypeError) and "argument" in str(cause):
            unimplemented(
                gb_type="TypeError when making fake tensor call",
                context=f"TypeError {node.target}: {cause}",
                explanation="",
                hints=[*graph_break_hints.USER_ERROR],
                from_exc=cause,
            )

        # Fallback for every other RuntimeError: substitute concrete hints for
        # symbolic sizes in the message, then surface it as a TorchRuntimeError.
        msg = get_concrete_sizes_from_symints(str(e), fake_mode)
        _wrap_graph_break_with_torch_runtime_err(
            lambda: unimplemented(
                gb_type="RuntimeError when making fake tensor call",
                context="",
                explanation=msg,
                hints=[*graph_break_hints.USER_ERROR],
                from_exc=cause,
            )
        )
        raise AssertionError("should not reachable") from None

    if not allow_non_graph_fake:
        # Run for its assertions only: every tensor in the result must belong
        # to this graph's fake mode.
        _ = pytree.tree_map_only(
            torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
        )

    if (
        torch._dynamo.config.use_graph_deduplication
        or torch._dynamo.config.track_nodes_for_deduplication
    ):
        tx.output.region_tracker.track_node_mutations(
            node,
            flat_args_kwargs,
            id_to_initial_version,
        )

    return ret_val
# Per-thread slot holding the FX node currently being executed by run_node.
_current_node = threading.local()


def get_current_node() -> Optional[torch.fx.Node]:
    """Return this thread's current node, or ``None`` if none was set."""
    try:
        return _current_node.value
    except AttributeError:
        return None


@contextmanager
def set_current_node(node: torch.fx.Node) -> Generator[None, None, None]:
    """Make ``node`` this thread's current node for the duration of the block.

    The previous value (possibly ``None``) is restored on exit, even when the
    body raises, so nested uses unwind correctly.
    """
    previous = get_current_node()
    _current_node.value = node
    try:
        yield
    finally:
        _current_node.value = previous
def run_node(
    tracer: Any, node: torch.fx.Node, args: Any, kwargs: Any, nnmodule: Any
) -> Any:
    """
    Runs a given node, with the given args and kwargs.

    Behavior is dictated by a node's op.

    run_node is useful for extracting real values out of nodes.
    See get_real_value for more info on common usage.

    Note: The tracer arg is only used for 'get_attr' ops
    Note: The nnmodule arg is only used for 'call_module' ops

    Nodes whose op is not call_function, call_method, call_module, get_attr, or
    placeholder will raise an AssertionError.

    While the node runs, it is installed as the current node (see
    get_current_node). Errors are translated as follows:
      * NotImplementedError / UnsupportedFakeTensorException -> graph break
        (Unsupported);
      * call_method on an object lacking the target attribute -> graph break;
      * Unsupported -> re-raised unchanged;
      * any other exception (including the internal asserts for call_module
        and placeholder) -> RuntimeError describing the node, chained to the
        original exception.
    """
    op = node.op

    with set_current_node(node):

        def make_error_message(e: Any) -> str:
            # NB: the wording mentions fake tensors even when called with real
            # values (e.g. from get_real_value).
            return (
                f"Dynamo failed to run FX node with fake tensors: {op} {node.target}(*{args}, **{kwargs}): got "
                + repr(e)
            )

        from .exc import Unsupported

        try:
            if op == "call_function":
                return node.target(*args, **kwargs)  # type: ignore[operator]
            elif op == "call_method":
                if not hasattr(args[0], node.target):  # type: ignore[arg-type]
                    from . import graph_break_hints
                    from .exc import unimplemented

                    unimplemented(
                        gb_type="Missing attribute when running call_method node",
                        context="",
                        explanation=make_error_message("attribute not defined"),
                        hints=[*graph_break_hints.USER_ERROR],
                    )
                # args[0] is the receiver; the rest are the method's arguments.
                return getattr(args[0], node.target)(*args[1:], **kwargs)  # type: ignore[arg-type]
            elif op == "call_module":
                assert nnmodule is not None
                return nnmodule(*args, **kwargs)
            elif op == "get_attr":
                return tracer.output_graph.get_submodule(node.target)
            elif op == "placeholder":
                assert "example_value" in node.meta
                return node.meta["example_value"]

        except (NotImplementedError, UnsupportedFakeTensorException) as e:
            # NB: mimic how wrap_fake_exception does it
            from . import graph_break_hints
            from .exc import unimplemented

            hints = [*graph_break_hints.USER_ERROR]
            if isinstance(e, NotImplementedError):
                hints += [
                    "If the op is a custom op, did you implement a fake tensor implementation? "
                    "(e.g. with `@my_custom_op.register_fake`)",
                    "If the op is a PyTorch op, please file an issue to PyTorch.",
                ]

            unimplemented(
                gb_type="NotImplementedError/UnsupportedFakeTensorException when running FX node",
                context="",
                explanation=make_error_message(e),
                hints=hints,
                from_exc=e,
            )
        except Unsupported:
            raise
        except Exception as e:
            # Keep the original traceback but wrap in a RuntimeError that names
            # the node; callers classify failures via `__cause__`.
            raise RuntimeError(make_error_message(e)).with_traceback(
                e.__traceback__
            ) from e

    # Reached only for ops not handled above.
    raise AssertionError(op)
def get_real_value(node: torch.fx.Node, tracer: Any) -> Any:
    """
    Run the actual computation represented by `node` and return the result.
    This will execute any dependent nodes in the graph as well.

    Results are memoized in `tracer.real_value_cache`. Placeholders that carry
    a `grapharg` are answered from its `example` and are not cached here.
    Non-lazy modules are deep-copied before being called; lazy modules are
    called once on the real inputs to initialize them, then run again via
    run_node.

    Raises:
        TorchRuntimeError: if running the node raises a RuntimeError.
    """
    from . import graph_break_hints
    from .exc import unimplemented

    cache = tracer.real_value_cache
    if node in cache:
        return cache[node]

    op = node.op
    # Recursively materialize every input node first.
    args, kwargs = torch.fx.node.map_arg(  # type: ignore[misc]
        (node.args, node.kwargs),
        lambda n: get_real_value(n, tracer),
    )

    if op == "placeholder" and "grapharg" in node.meta:
        return node.meta["grapharg"].example

    if op == "call_module":
        nn_module = tracer.output_graph.nn_modules[node.target]
        if not is_lazy_module(nn_module):
            # Presumably copied so running it does not mutate the traced
            # module's state — confirm.
            nn_module = copy.deepcopy(nn_module)
        else:
            # In the case of a lazy module, we want to run
            # the pre-hooks which initialize it
            nn_module(*args, **kwargs)
    else:
        nn_module = None

    try:
        real_value = run_node(tracer, node, args, kwargs, nn_module)
        cache[node] = real_value
    except RuntimeError as e:
        exn = e  # to make typing happy for the lambda
        _wrap_graph_break_with_torch_runtime_err(
            lambda: unimplemented(
                gb_type="RuntimeError when trying to get real value from fx.Node",
                context="",
                explanation="",
                hints=[*graph_break_hints.USER_ERROR],
                from_exc=exn,
            )
        )
        raise AssertionError("should not be reachable") from None

    return real_value
- def assert_no_fake_params_or_buffers(gm: torch.nn.Module) -> None:
- from torch._subclasses.fake_tensor import FakeTensorConfig, is_fake
- def stack_or_hint(t: Any) -> str:
- if FakeTensorConfig.debug:
- import traceback
- return f"FAKE TENSOR CREATION TRACEBACK: \n {traceback.format_list(t._debug_trace)}"
- else:
- return "Enable TORCH_FAKE_TENSOR_DEBUG=1 to get creation stack traces on fake tensors."
- for name, buffer in gm.named_buffers():
- assert not is_fake(buffer), (
- f"Unexpected fake buffer {name} {stack_or_hint(buffer)}"
- )
- for name, param in gm.named_parameters():
- assert not is_fake(param), (
- f"Unexpected fake param {name} {stack_or_hint(param)}"
- )
- def fqn(obj: Any) -> str:
- """
- Returns the fully qualified name of the object.
- """
- return f"{obj.__module__}.{obj.__qualname__}"
- def ifdynstaticdefault(count1: Any, count2: Any) -> Any:
- if torch._dynamo.config.assume_static_by_default:
- return count1
- else:
- return count2
def import_submodule(mod: types.ModuleType) -> None:
    """
    Ensure all the files in a given submodule are imported

    Every ``*.py`` file in the directory of ``mod.__file__`` whose name does
    not start with ``_`` is imported as ``<mod>.<stem>``, in sorted order.
    """
    package_dir = os.path.dirname(cast(str, mod.__file__))
    for entry in sorted(os.listdir(package_dir)):
        if not entry.endswith(".py") or entry[0] == "_":
            continue
        importlib.import_module(f"{mod.__name__}.{entry[:-3]}")
def object_has_getattribute(value: Any) -> bool:
    """Return whether ``value``'s class defines a Python-level ``__getattribute__``."""
    value_cls = type(value)
    return class_has_getattribute(value_cls)
def object_setattr_ignore_descriptor(obj: Any, name: str, value: Any) -> None:
    """Write ``value`` straight into ``obj.__dict__[name]``.

    ``__dict__`` is fetched with ``object.__getattribute__`` and written as a
    plain mapping, so neither a custom ``__setattr__``/``__getattribute__`` nor
    a descriptor named ``name`` is involved in the write. See CPython's generic
    setattr: https://github.com/python/cpython/blob/3.11/Objects/object.c#L1286-L1335
    """
    object.__getattribute__(obj, "__dict__")[name] = value
def class_has_getattribute(cls: type) -> bool:
    """Return True if ``cls`` statically resolves ``__getattribute__`` to a Python function.

    Built-in slot wrappers (e.g. ``object.__getattribute__``) do not count.
    """
    try:
        found = inspect.getattr_static(cls, "__getattribute__")
    except AttributeError:
        return False
    return isinstance(found, types.FunctionType)
def get_custom_getattr(
    value: Any, ignore_nn_module_getattr: bool = False
) -> Optional[Any]:
    """Return the ``__getattr__`` statically defined on ``type(value)``, if any.

    Args:
        value: the object whose class is inspected (without invoking it).
        ignore_nn_module_getattr: when True, ``torch.nn.Module.__getattr__``
            itself is treated as "no custom getattr".

    Returns:
        The ``__getattr__`` object, or ``None``.
    """
    try:
        found = inspect.getattr_static(type(value), "__getattr__")
    except AttributeError:
        return None
    if ignore_nn_module_getattr and found is torch.nn.Module.__getattr__:
        # ignore this case of getattr
        return None
    return found
class TensorStaticReason(enum.Enum):
    """Reasons a tensor's shape is forced to be static."""

    PARAMETER = 2
    NOT_TENSOR = 4
    NN_MODULE_PROPERTY = 5


def tensor_static_reason_to_message(reason: TensorStaticReason) -> str:
    """Return the user-facing explanation for ``reason``.

    Raises:
        AssertionError: if ``reason`` is not one of the known members.
    """
    explanations = (
        (
            TensorStaticReason.PARAMETER,
            "mark_dynamic on parameter, parameters are always static today.",
        ),
        (
            TensorStaticReason.NOT_TENSOR,
            "mark_dynamic on a non tensor, how did this happen?",
        ),
        (
            TensorStaticReason.NN_MODULE_PROPERTY,
            "tensor is static because it is nn module associated.",
        ),
    )
    for member, text in explanations:
        if reason == member:
            return text
    raise AssertionError(f"Illegal reason {reason}")
def tensor_always_has_static_shape(
    tensor: Union[torch.Tensor, Any],
    is_tensor: bool,
    tensor_source: Source,
) -> tuple[bool, Optional[TensorStaticReason]]:
    """
    Given a tensor, source, and is_tensor flag, determine if a shape should be static.

    Args:
    tensor - the real tensor to evaluate, parameters force a static shape.
    is_tensor - internal dynamo check, essentially "is_tensor": target_cls is TensorVariable,
    tensors not in a TensorVariable for whatever reason are forced static.

    Checks, in order: nn-module-owned sources (when
    config.force_nn_module_property_static_shapes), parameters or
    unspecialized param/buffer sources (when
    config.force_parameter_static_shapes), then ``is_tensor``.

    Returns a tuple, where the first element is the bool of whether or not this tensor should have a static shape.
    The second element is a TensorStaticReason, useful for passing to tensor_static_reason_to_message if needed.
    """
    from .source import is_from_unspecialized_param_buffer_source

    guard_source = tensor_source.guard_source
    from_nn_module = (
        guard_source.is_specialized_nn_module()
        or guard_source.is_unspecialized_builtin_nn_module()
    )
    if from_nn_module and config.force_nn_module_property_static_shapes:
        return True, TensorStaticReason.NN_MODULE_PROPERTY

    looks_like_param = type(
        tensor
    ) is torch.nn.Parameter or is_from_unspecialized_param_buffer_source(tensor_source)
    if looks_like_param and config.force_parameter_static_shapes:
        return True, TensorStaticReason.PARAMETER

    return (True, TensorStaticReason.NOT_TENSOR) if not is_tensor else (False, None)
def lazy_format_graph_tabular(fn_name: str, gm: torch.fx.GraphModule) -> Any:
    """Return a ``LazyString`` that renders ``gm``'s graph as a table when formatted.

    Each row is (opcode, name, target, args, kwargs) for one node. If the
    optional ``tabulate`` package is missing, the graph code from
    ``lazy_format_graph_code`` is logged instead, prefixed with a notice.
    """

    def render() -> str:
        try:
            from tabulate import tabulate  # optional dependency
        except ImportError:
            return (
                "Tabulate module missing, please install tabulate to log the graph in tabular format, logging code instead:\n"
                + str(lazy_format_graph_code(fn_name, gm))
            )

        rows = [
            [node.op, node.name, node.target, node.args, node.kwargs]
            for node in gm.graph.nodes
        ]
        table = tabulate(rows, headers=["opcode", "name", "target", "args", "kwargs"])
        source_file = gm.forward.__code__.co_filename
        return _format_graph_code(fn_name, source_file, table)

    return LazyString(render)
def format_bytecode(
    prefix: str, name: str, filename: str, line_no: int, code: Any
) -> str:
    """Return a header line followed by the ``dis`` listing of ``code``."""
    listing = dis.Bytecode(code).dis()
    header = f"{prefix} {name} {filename} line {line_no} "
    return f"{header}\n{listing}\n"
# Attribute names read off nn.Module instances when collecting hooks,
# grouped by category.
forward_hook_names = [
    "_forward_pre_hooks",
    "_forward_pre_hooks_with_kwargs",
    "_forward_hooks_with_kwargs",
    "_forward_hooks",
]
backward_hook_names = [
    "_backward_pre_hooks",
    "_backward_hooks",
]
state_dict_hook_names = [
    "_state_dict_pre_hooks",
    "_state_dict_hooks",
    "_load_state_dict_pre_hooks",
    "_load_state_dict_post_hooks",
]
# Every category above, concatenated in declaration order.
all_hook_names = [*forward_hook_names, *backward_hook_names, *state_dict_hook_names]
def nn_module_has_global_hooks() -> bool:
    """Return True if any global nn.Module backward (pre-)hook is registered.

    Only backward hooks are considered, because NNModuleVariable handles
    forward hooks itself.
    """
    module_impl = torch.nn.modules.module
    registries = (
        module_impl._global_backward_hooks,
        module_impl._global_backward_pre_hooks,
    )
    return any(len(registry) > 0 for registry in registries)
def nn_module_get_all_hooks(
    mod: torch.nn.Module,
    check_forward_hooks: bool = False,
    check_backward_hooks: bool = False,
    check_state_dict_hooks: bool = False,
) -> list[Any]:
    """
    Sometimes it's useful to differentiate between types of hooks such as forward/backward/pre
    hooks executed during module.__call__, and state_dict hooks which are executed separately.

    Each ``check_*`` flag selects the attribute names in the corresponding
    ``*_hook_names`` list. When no flag is set, forward and backward hooks are
    collected but state_dict hooks are not. Attributes missing from ``mod``
    are skipped. The values of each selected mapping are returned in
    iteration order.
    """
    no_flag_given = not (
        check_forward_hooks or check_backward_hooks or check_state_dict_hooks
    )

    attr_names: list[str] = []
    if no_flag_given or check_forward_hooks:
        attr_names += forward_hook_names
    if no_flag_given or check_backward_hooks:
        attr_names += backward_hook_names
    if check_state_dict_hooks:
        attr_names += state_dict_hook_names

    collected: list[Any] = []
    for attr_name in attr_names:
        registry = getattr(mod, attr_name, [])
        collected.extend(registry[key] for key in registry)
    return collected
def nnmodule_has_hooks(
    mod: torch.nn.Module,
    check_forward_hooks: bool = False,
    check_backward_hooks: bool = False,
    check_state_dict_hooks: bool = False,
) -> bool:
    """Return True if ``mod`` has at least one hook in the selected categories.

    Category selection follows ``nn_module_get_all_hooks``.
    """
    found = nn_module_get_all_hooks(
        mod,
        check_forward_hooks=check_forward_hooks,
        check_backward_hooks=check_backward_hooks,
        check_state_dict_hooks=check_state_dict_hooks,
    )
    return len(found) > 0
- def to_numpy_helper(value: Any) -> Any:
- """Convert tensor and tnp.ndarray to numpy.ndarray."""
- if is_fake(value):
- return value
- if isinstance(value, tnp.ndarray):
- return to_numpy_helper(value.tensor)
- elif isinstance(value, torch.Tensor):
- return value.numpy(force=True)
- elif isinstance(value, (tuple, list)):
- return type(value)(to_numpy_helper(obj) for obj in value)
- else:
- return value
- def numpy_to_tensor(value: Any) -> Any:
- """Convert tnp.ndarray to tensor, leave other types intact. If a list/tuple, loop through it to convert."""
- assert np is not None
- if isinstance(value, np.ndarray):
- return torch.as_tensor(value)
- if isinstance(value, tnp.ndarray):
- return value.tensor
- elif isinstance(value, (tuple, list)):
- return type(value)(numpy_to_tensor(obj) for obj in value)
- else:
- return value
class numpy_to_tensor_wrapper(Generic[_P, R]):
    """Callable wrapper that passes ``f``'s result through ``numpy_to_tensor``."""

    def __init__(self, f: Callable[_P, R]) -> None:
        self.f = f
        self.__name__ = "wrapped_" + f.__name__

    def __repr__(self) -> str:
        original_name = self.f.__name__
        return f"<Wrapped function <original {original_name}>>"

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any:
        return numpy_to_tensor(self.f(*args, **kwargs))
def numpy_attr_wrapper(obj: Any, name: str) -> Any:
    """Read attribute ``name`` through ``obj``'s ``tnp.ndarray`` view.

    A ``torch.Tensor`` is first wrapped in ``tnp.ndarray``. The attribute
    value is converted back with ``numpy_to_tensor``. Returns ``None`` for any
    other kind of ``obj``.
    """
    if isinstance(obj, tnp.ndarray):
        ndarray_view = obj
    elif isinstance(obj, torch.Tensor):
        ndarray_view = tnp.ndarray(obj)
    else:
        return None
    return numpy_to_tensor(getattr(ndarray_view, name))
class numpy_method_wrapper:
    """Convert obj from torch.Tensor to tnp.ndarray and call method. Then convert result back to torch.Tensor."""

    def __init__(self, method: str) -> None:
        self.method = method
        self.__name__ = "wrapped_" + method

    def __repr__(self) -> str:
        return f"<Wrapped method <original {self.method}>>"

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # The first positional argument is the receiver of the method call.
        receiver = args[0]
        if isinstance(receiver, torch.Tensor):
            receiver = tnp.ndarray(receiver)
        bound_method = getattr(receiver, self.method)
        return numpy_to_tensor(bound_method(*args[1:], **kwargs))
class numpy_operator_wrapper(Generic[_P, R]):
    """Implements dunder methods for tnp.ndarray via functions from the operator library.

    Calling the wrapper converts every ``torch.Tensor`` positional argument to
    a ``tnp.ndarray``, applies ``op``, and converts the result back with
    ``numpy_to_tensor``. Keyword arguments are not supported.
    """

    def __init__(self, op: Callable[..., Any]) -> None:
        self.op = op
        self.__name__ = f"wrapped_{op.__name__}"

    def __repr__(self) -> str:
        # Show the wrapped operator's own name. ``self.__name__`` already
        # carries the "wrapped_" prefix and would print "<original wrapped_add>".
        return f"<Wrapped operator <original {self.op.__name__}>>"

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> Any:
        assert not kwargs
        ndarray_args = tuple(
            tnp.ndarray(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
        )
        out = self.op(*ndarray_args)
        return numpy_to_tensor(out)
def defake(x: Any) -> Any:
    """Materialize a real, zero-filled tensor with the metadata of fake tensor ``x``.

    Non-``FakeTensor`` inputs are returned unchanged. Symbolic sizes and
    strides are replaced by their shape-env size hints. The result has the
    same dtype, device and ``requires_grad`` flag as ``x`` and is a leaf.
    """
    if not isinstance(x, FakeTensor):
        return x

    size: torch._prims_common.ShapeType
    stride: torch._prims_common.StrideType
    if x._has_symbolic_sizes_strides:
        size = []
        for s in x.size():
            if isinstance(s, torch.SymInt):
                size.append(s.node.shape_env.size_hint(s.node.expr))
            else:
                size.append(s)
        stride = []
        for s in x.stride():
            if isinstance(s, torch.SymInt):
                stride.append(s.node.shape_env.size_hint(s.node.expr))
            else:
                stride.append(s)
    else:
        size = x.size()
        stride = x.stride()

    y = torch.empty_strided(
        size,
        stride,
        dtype=x.dtype,
        device=x.device,
    )
    # Zero-fill before enabling autograd: an in-place op on a leaf that
    # already requires grad raises a RuntimeError.
    y.zero_()
    y.requires_grad_(x.requires_grad)
    return y
def _disable_side_effect_safety_checks_for_current_subtracer(
    fn: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs
) -> R:
    """Call ``fn(*args, **kwargs)`` and return its result.

    As written here it is a plain pass-through.
    NOTE(review): the name suggests Dynamo special-cases this function while
    tracing to relax side-effect checks — confirm where that happens.
    """
    result = fn(*args, **kwargs)
    return result
def is_utils_checkpoint(obj: Any) -> bool:
    """Return True if ``obj`` is exactly ``torch.utils.checkpoint.checkpoint``."""
    # Lazy import to avoid circular dependencies
    from torch.utils.checkpoint import checkpoint

    return obj is checkpoint
- def is_invoke_subgraph(obj: Any) -> bool:
- from torch._higher_order_ops.invoke_subgraph import invoke_subgraph_placeholder
- return obj is invoke_subgraph_placeholder
def build_invoke_subgraph_variable(**options: Any) -> Any:
    """Create the TorchHigherOrderOperatorVariable for
    ``torch._higher_order_ops.invoke_subgraph``, forwarding ``options``."""
    from .variables.higher_order_ops import TorchHigherOrderOperatorVariable

    target = torch._higher_order_ops.invoke_subgraph
    return TorchHigherOrderOperatorVariable.make(target, **options)
def build_checkpoint_variable(**options: Any) -> Any:
    """Create the TorchHigherOrderOperatorVariable for activation checkpointing.

    Uses ``wrap_activation_checkpoint`` when
    ``torch._functorch.config.functionalize_rng_ops`` is set, otherwise
    ``tag_activation_checkpoint``. ``options`` are forwarded to ``make``.
    """
    import torch._higher_order_ops.wrap as higher_order_ops

    from .variables.higher_order_ops import TorchHigherOrderOperatorVariable

    # TODO - This is a temporary situation where we have two versions of
    # checkpointing implementation. We will converge on one and remove the other.
    activation_checkpoint_op: torch._ops.HigherOrderOperator = (
        higher_order_ops.wrap_activation_checkpoint
        if torch._functorch.config.functionalize_rng_ops
        else higher_order_ops.tag_activation_checkpoint
    )
    return TorchHigherOrderOperatorVariable.make(activation_checkpoint_op, **options)
def is_compile_supported(device_type: DeviceLikeType) -> Any:
    """Return whether torch.compile can be used on ``device_type``.

    CPU needs only Dynamo support. cuda/xpu/mtia additionally need Triton
    (checked only when Dynamo is supported). Every other device type is
    unsupported.
    """
    from .eval_frame import is_dynamo_supported

    # Named to avoid shadowing the builtin `type`.
    device_kind = torch.device(device_type).type
    dynamo_ok = is_dynamo_supported()
    if device_kind == "cpu":
        return dynamo_ok
    if device_kind in ("cuda", "xpu", "mtia"):
        return dynamo_ok and has_triton()
    return False
- # The following 3.11 source code functions are adapted from
- # https://github.com/python/cpython/blob/v3.11.4/Lib/traceback.py
- # in order to output source code corresponding to bytecode in 3.11+.
- # We need our own versions since we want to support multiline expressions.
- def _fix_offset(str: str, offset: int) -> int:
- """
- Convert byte offset `offset` of `str` into character offset.
- Byte offset is used for 3.11+ instruction column data.
- Takes things like unicode characters into consideration.
- Unchanged from CPython implementation.
- """
- as_utf8 = str.encode("utf-8")
- return len(as_utf8[:offset].decode("utf-8", errors="replace"))
- @dataclasses.dataclass
- class _Anchors:
- # inclusive
- left_end_lineno: int
- left_end_offset: int
- right_start_lineno: int
- # exclusive
- right_start_offset: int
def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
    """
    Given source code `segment` corresponding to a bytecode
    instruction, determine:
        - for binary ops, the location of the binary op
        - for indexing, the location of the brackets.
        - for calls, the span from the opening parenthesis to the end of the call.
    `segment` is expected to be a valid Python expression

    Returns None when `segment` does not parse as a single expression or is
    not a BinOp/Subscript/Call. Line numbers in the result index
    `segment.split("\n")`; columns are character offsets.
    """
    assert sys.version_info >= (3, 11)

    import ast

    tree: Any | None = None
    try:
        # Without brackets, `segment` is parsed as a statement.
        # We expect an expression, so wrap `segment` in
        # brackets to handle multi-line expressions.
        tree = ast.parse("(\n" + segment + "\n)")
    except SyntaxError:
        return None

    assert tree is not None

    if len(tree.body) != 1:
        return None

    lines = segment.split("\n")

    # get character index given byte offset
    def normalize(lineno: int, offset: int) -> int:
        return _fix_offset(lines[lineno], offset)

    # Gets the next valid character index in `lines`, if
    # the current location is not valid. Handles empty lines.
    def next_valid_char(lineno: int, col: int) -> tuple[int, int]:
        while lineno < len(lines) and col >= len(lines[lineno]):
            col = 0
            lineno += 1
        assert lineno < len(lines) and col < len(lines[lineno])
        return lineno, col

    # Get the next valid character index in `lines`.
    def increment(lineno: int, col: int) -> tuple[int, int]:
        col += 1
        lineno, col = next_valid_char(lineno, col)
        assert lineno < len(lines) and col < len(lines[lineno])
        return lineno, col

    # Get the next valid character at least on the next line
    def nextline(lineno: int, col: int) -> tuple[int, int]:
        col = 0
        lineno += 1
        lineno, col = next_valid_char(lineno, col)
        assert lineno < len(lines) and col < len(lines[lineno])
        return lineno, col

    statement = tree.body[0]
    if isinstance(statement, ast.Expr):
        expr = statement.value
        if isinstance(expr, ast.BinOp):
            # ast gives locations for BinOp subexpressions, e.g.
            # ( left_expr ) + ( right_expr )
            #   left^^^^^       right^^^^^
            # -2 since end_lineno is 1-indexed and because we added an extra
            # line ("(\n") before `segment` when calling ast.parse
            cur_lineno = cast(int, expr.left.end_lineno) - 2
            assert expr.left.end_col_offset is not None
            cur_col = normalize(cur_lineno, expr.left.end_col_offset)
            cur_lineno, cur_col = next_valid_char(cur_lineno, cur_col)

            # Heuristic to find the operator character.
            # The original CPython implementation did not look for ), \, or #,
            # leading to incorrect anchor location, e.g.
            # (x) + (y)
            # ~~^~~~~~~
            # Skip whitespace and closing parens; a line continuation or
            # comment sends the search to the next line.
            while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
                if ch in "\\#":
                    cur_lineno, cur_col = nextline(cur_lineno, cur_col)
                else:
                    cur_lineno, cur_col = increment(cur_lineno, cur_col)

            # binary op is 1 or 2 characters long, on the same line
            # NOTE(review): any non-space character right after the operator
            # (other than \ or #) is taken as its second character, so e.g.
            # `a +-b` anchors on "+-".
            right_col = cur_col + 1
            if (
                right_col < len(lines[cur_lineno])
                and not (ch := lines[cur_lineno][right_col]).isspace()
                and ch not in "\\#"
            ):
                right_col += 1
            # right_col can be invalid since it is exclusive

            return _Anchors(cur_lineno, cur_col, cur_lineno, right_col)
        elif isinstance(expr, ast.Subscript):
            # ast gives locations for value and slice subexpressions, e.g.
            # ( value_expr ) [ slice_expr ]
            #   value^^^^^     slice^^^^^
            # subscript^^^^^^^^^^^^^^^^^^^^
            # find left bracket (first '[' after value)
            left_lineno = cast(int, expr.value.end_lineno) - 2
            assert expr.value.end_col_offset is not None
            left_col = normalize(left_lineno, expr.value.end_col_offset)
            left_lineno, left_col = next_valid_char(left_lineno, left_col)
            while lines[left_lineno][left_col] != "[":
                left_lineno, left_col = increment(left_lineno, left_col)
            # find right bracket (final character of expression)
            # (end_col_offset is exclusive, so this points just past it)
            right_lineno = cast(int, expr.end_lineno) - 2
            assert expr.end_col_offset is not None
            right_col = normalize(right_lineno, expr.end_col_offset)
            return _Anchors(left_lineno, left_col, right_lineno, right_col)
        elif isinstance(expr, ast.Call):
            # ( func_expr ) (args, kwargs)
            #   func^^^^^
            # call^^^^^^^^^^^^^^^^^^^^^^^^
            # find left bracket (first '(' after func)
            left_lineno = cast(int, expr.func.end_lineno) - 2
            assert expr.func.end_col_offset is not None
            left_col = normalize(left_lineno, expr.func.end_col_offset)
            left_lineno, left_col = next_valid_char(left_lineno, left_col)
            while lines[left_lineno][left_col] != "(":
                left_lineno, left_col = increment(left_lineno, left_col)
            # find right bracket (final character of expression)
            # (end_col_offset is exclusive, so this points just past it)
            right_lineno = cast(int, expr.end_lineno) - 2
            assert expr.end_col_offset is not None
            right_col = normalize(right_lineno, expr.end_col_offset)
            return _Anchors(left_lineno, left_col, right_lineno, right_col)

    return None
def get_instruction_source_311(code: types.CodeType, inst: Instruction) -> str:
    """
    Python 3.11+ only. Returns lines of source code (from code object `code`)
    corresponding to `inst`'s location data, and underlines relevant code to `inst`.

    Example: CALL on `g`:
    f(g(
      ^^
        h(x)))
        ^^^^^

    We need our own implementation in < 3.13 since `format_frame_summary` in
    Python's `traceback` module doesn't handle multi-line expressions
    (and their anchor extraction code is not completely correct).

    Args:
        code: code object whose ``co_filename`` is used to read source lines
            through ``linecache``.
        inst: instruction whose ``positions`` (lineno, end_lineno, col_offset,
            end_col_offset) select the source span to show.

    Returns:
        ``""`` when the instruction has no line number; only the first source
        line when end position information is missing; otherwise the source
        lines interleaved with marker lines (``~``/``^``), or on 3.13+ the
        snippet produced by ``traceback.format_list`` re-indented to match the
        original source.
    """
    assert hasattr(inst, "positions") and inst.positions is not None
    # Validate position info before choosing an implementation. Previously the
    # 3.13+ branch ran first and crashed (e.g. `range(None, ...)`) on
    # instructions without line information instead of degrading gracefully.
    if inst.positions.lineno is None:
        return ""

    if sys.version_info >= (3, 13) and inst.positions.end_lineno is not None:
        # multiline traceback implemented in 3.13+
        frame_summary = traceback.FrameSummary(
            code.co_filename,
            inst.positions.lineno,
            code.co_name,
            end_lineno=inst.positions.end_lineno,
            colno=inst.positions.col_offset,
            end_colno=inst.positions.end_col_offset,
        )
        result = traceback.format_list([frame_summary])[0]
        # remove first line containing filename info
        result = "\n".join(result.splitlines()[1:])
        # indent lines with original indentation
        orig_lines = [
            linecache.getline(code.co_filename, lineno).rstrip()
            for lineno in range(inst.positions.lineno, inst.positions.end_lineno + 1)
        ]
        orig_lines_dedent = textwrap.dedent("\n".join(orig_lines)).splitlines()
        # linecache yields "" for unavailable files, which dedents to no lines
        # at all; fall back to no indentation rather than raising IndexError.
        if orig_lines and orig_lines_dedent:
            indent_len = len(orig_lines[0]) - len(orig_lines_dedent[0])
            indent = orig_lines[0][:indent_len]
        else:
            indent = ""
        result = textwrap.indent(textwrap.dedent(result), indent)
        return result

    # The rstrip + "\n" pattern is used throughout this function to handle
    # linecache.getline errors. Error lines are treated as empty strings "", but we want
    # to treat them as blank lines "\n".
    first_line = linecache.getline(code.co_filename, inst.positions.lineno).rstrip()
    if inst.positions.end_lineno is None:
        return first_line
    if inst.positions.col_offset is None or inst.positions.end_col_offset is None:
        return first_line

    # character index of the start of the instruction
    start_offset = _fix_offset(first_line, inst.positions.col_offset)
    # character index of the end of the instruction
    # compute later since end may be a different line
    end_offset = None
    # expression corresponding to the instruction so we can get anchors
    segment = ""
    # underline markers to be printed - start with `~` marker and replace with `^` later
    markers = []

    # Compute segment and initial markers
    if inst.positions.end_lineno == inst.positions.lineno:
        end_offset = _fix_offset(first_line, inst.positions.end_col_offset)
        segment = first_line[start_offset:end_offset]
        markers.append(" " * start_offset + "~" * (end_offset - start_offset))
    else:
        segment = first_line[start_offset:] + "\n"
        markers.append(" " * start_offset + "~" * (len(first_line) - start_offset))
        last_line = linecache.getline(
            code.co_filename, inst.positions.end_lineno
        ).rstrip()
        end_offset = _fix_offset(last_line, inst.positions.end_col_offset)
        for lineno in range(inst.positions.lineno + 1, inst.positions.end_lineno):
            line = linecache.getline(code.co_filename, lineno).rstrip()
            segment += line + "\n"
            # don't underline leading spaces
            num_spaces = len(line) - len(line.lstrip())
            markers.append(" " * num_spaces + "~" * (len(line) - num_spaces))
        segment += last_line[:end_offset]
        num_spaces = len(last_line) - len(last_line.lstrip())
        markers.append(" " * num_spaces + "~" * (end_offset - num_spaces))

    anchors: Optional[_Anchors] = None
    try:
        anchors = _extract_anchors_from_expr(segment)
    except AssertionError:
        pass

    # replace `~` markers with `^` where necessary
    if anchors is None:
        markers = [marker.replace("~", "^") for marker in markers]
    else:
        # make markers mutable
        mutable_markers: list[list[str]] = [list(marker) for marker in markers]

        # anchor positions do not take start_offset into account
        if anchors.left_end_lineno == 0:
            anchors.left_end_offset += start_offset
        if anchors.right_start_lineno == 0:
            anchors.right_start_offset += start_offset

        # Turn `~` markers between anchors to `^`
        for lineno in range(len(markers)):
            for col in range(len(mutable_markers[lineno])):
                if lineno < anchors.left_end_lineno:
                    continue
                if lineno == anchors.left_end_lineno and col < anchors.left_end_offset:
                    continue
                if (
                    lineno == anchors.right_start_lineno
                    and col >= anchors.right_start_offset
                ):
                    continue
                if lineno > anchors.right_start_lineno:
                    continue
                if mutable_markers[lineno][col] == "~":
                    mutable_markers[lineno][col] = "^"

        # make markers into strings again
        markers = ["".join(marker) for marker in mutable_markers]

    result = ""
    for i in range(len(markers)):
        result += (
            linecache.getline(code.co_filename, inst.positions.lineno + i).rstrip()
            + "\n"
        )
        result += markers[i] + "\n"
    return result
def get_static_address_type(t: Any) -> Any:
    """Return the ``_dynamo_static_input_type`` tag stored on tensor ``t``.

    Non-tensors, and tensors that were never tagged, yield ``None``.
    """
    if not isinstance(t, torch.Tensor):
        return None
    return getattr(t, "_dynamo_static_input_type", None)
def is_rng_state_getter_or_setter(value: Any) -> bool:
    """Return True if ``value`` is one of the known RNG-state get/set callables.

    Matching uses ordinary ``in`` semantics (identity, then ``==``).
    """
    # The Generator-class methods and the default-generator bound methods are
    # not identical objects, so every variant is listed explicitly.
    rng_state_fns = (
        torch._C.Generator.set_state,
        torch.default_generator.set_state,
        torch.set_rng_state,
        torch.cuda.set_rng_state,
        torch._C.Generator.get_state,
        torch.default_generator.get_state,
        torch.get_rng_state,
        torch.cuda.get_rng_state,
    )
    return value in rng_state_fns
def is_tensor_base_attr_getter(value: Any) -> bool:
    """Return True if ``value`` is the bound ``__get__`` of a ``torch._C._TensorBase`` descriptor.

    Such values are method-wrappers named ``__get__`` whose ``__self__`` is a
    descriptor with ``__objclass__`` equal to ``torch._C._TensorBase``.
    """
    if not isinstance(value, types.MethodWrapperType):
        return False
    if value.__name__ != "__get__":
        return False
    # Not every `__get__` method-wrapper is bound to a descriptor with
    # `__objclass__` (e.g. `(lambda: 0).__get__` is bound to a plain function),
    # so look it up defensively instead of raising AttributeError.
    objclass = getattr(value.__self__, "__objclass__", None)
    return objclass is not None and objclass is torch._C._TensorBase  # type: ignore[attr-defined]
def is_tensor_getset_descriptor(name: str) -> bool:
    """Return True if ``torch.Tensor.<name>`` is a getset descriptor.

    The attribute is fetched with ``inspect.getattr_static`` so no descriptor
    is invoked; an unknown ``name`` yields False.
    """
    try:
        static_attr = inspect.getattr_static(torch.Tensor, name)
    except AttributeError:
        return False
    return type(static_attr) is types.GetSetDescriptorType
def is_torch_function_object(value: Any) -> bool:
    """Return True if ``value`` exposes a ``__torch_function__`` attribute."""
    try:
        value.__torch_function__  # noqa: B018 - only probing for existence
    except AttributeError:
        return False
    return True
def has_torch_function(vt: VariableTracker) -> bool:
    """Return True if ``vt`` tracks an object with an active ``__torch_function__``.

    This emulates
    https://github.com/pytorch/pytorch/blob/8d81806211bc3c0ee6c2ef235017bacf1d775a85/torch/csrc/utils/disable_torch_function.cpp#L315-L323

    Only a TensorWithTFOverrideVariable (checked on its ``class_type``) or a
    UserDefinedObjectVariable (checked on its ``value``) can report True, and
    only when the override is neither missing nor
    ``torch._C._disabled_torch_function_impl``.
    """
    from torch._dynamo.variables import UserDefinedObjectVariable
    from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable

    # Note on lazy vars: The value will either be realized or not throughout the course of execution
    # if the value has a torch function, it will eventually be realized so we can realize it here
    # if the value does not have a torch function, it may or may not be realized
    # if it is realized it will be used and guards will be installed properly
    # if it is not used, guards won't be installed, and it doesn't matter
    # if the value has a torch function or not, so we should *not* realize it.
    # NB: We technically know that if is_realized is False, LazyVariableTracker has the peek_value method
    # but mypy does not unfortunately
    worth_inspecting = vt.is_realized() or (
        hasattr(vt, "peek_value") and hasattr(vt.peek_value(), "__torch_function__")
    )
    if not worth_inspecting:
        return False

    if isinstance(vt, TensorWithTFOverrideVariable):
        tf_impl = getattr(vt.class_type, "__torch_function__", None)
    elif isinstance(vt, UserDefinedObjectVariable):
        tf_impl = getattr(vt.value, "__torch_function__", None)
    else:
        tf_impl = None
    return tf_impl not in (None, torch._C._disabled_torch_function_impl)
- # see note [Tensor Fakification and Symbol Caching]
- def to_fake_tensor(
- t: torch.Tensor, fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
- ) -> Any:
- symbolic_context = None
- source = None
- if tracing_context := torch._guards.TracingContext.try_get():
- if t in tracing_context.tensor_to_context:
- symbolic_context = tracing_context.tensor_to_context[t]
- source = symbolic_context.tensor_source
- return fake_mode.from_tensor(
- t, static_shapes=False, symbolic_context=symbolic_context, source=source
- )
# NB: this works for both classes and instances
def is_frozen_dataclass(value: Any) -> bool:
    """Return True if ``value`` (a class or an instance) is a frozen dataclass.

    Values for which ``object_has_getattribute`` or ``class_has_getattribute``
    is true are rejected up front, before any dataclass attribute is read
    (presumably to avoid running user-defined attribute hooks).
    """
    if object_has_getattribute(value) or class_has_getattribute(value):
        return False
    if not is_dataclass(value):
        return False
    if not hasattr(value, "__dataclass_params__"):
        return False
    if not hasattr(value.__dataclass_params__, "frozen"):
        return False
    return value.__dataclass_params__.frozen
def get_first_attr(obj: Any, *attrs: str) -> Any:
    """
    Return the first available attribute or throw an exception if none is present.

    Args:
        obj: object to inspect.
        *attrs: candidate attribute names, tried in the given order.

    Raises:
        AssertionError: if ``obj`` has none of ``attrs``.
    """
    for attr in attrs:
        if hasattr(obj, attr):
            return getattr(obj, attr)

    raise AssertionError(f"{obj} does not have any of the attributes: {attrs}")
@contextlib.contextmanager
def maybe_enable_compiled_autograd(
    should_enable: bool, fullgraph: bool = True, dynamic: bool = True
) -> Generator[Any, None, None]:
    """Optionally enable compiled autograd for the duration of the ``with`` block.

    When ``should_enable`` is False this is a no-op that yields None. Otherwise
    it enters ``torch._dynamo.compiled_autograd._enable`` with a compiler that
    ``torch.compile``s each graph it receives (forwarding ``fullgraph`` and
    ``dynamic``) through a backend that increments
    ``counters["compiled_autograd"]["compiles"]`` and then calls
    ``torch._inductor.compile``. The value produced by ``_enable`` is yielded.
    """
    if not should_enable:
        yield
        return

    def compiler_fn(gm: Any) -> Any:
        # A fresh backend closure per graph, as before.
        def inner_compiler(gm_: Any, example_inputs_: Any) -> Any:
            torch._dynamo.utils.counters["compiled_autograd"]["compiles"] += 1
            return torch._inductor.compile(gm_, example_inputs_)

        return torch.compile(
            gm,
            backend=inner_compiler,
            fullgraph=fullgraph,
            dynamic=dynamic,
        )

    with torch._dynamo.compiled_autograd._enable(compiler_fn) as ctx:
        yield ctx
def invalid_removeable_handle() -> RemovableHandle:
    """Return a RemovableHandle pointing at a throwaway, empty dict.

    The handle's target is a ``dict`` subclass instance rather than a plain
    ``dict``, because plain dicts cannot be weakly referenced.
    """

    # need a subclass so weakref works
    class Invalid(dict):  # type: ignore[type-arg]
        pass

    placeholder = Invalid()
    return RemovableHandle(placeholder)
# Returns a "proxy" (new object with the same class and dict) for (non-GraphModule) nn.Module's.
# Attribute changes to the original object/proxy will be reflected in the other.
# This is useful for cases where we want a keep-alive reference to a module without increasing
# its reference count.
def nn_module_proxy(mod: Any) -> Any:
    """Return a same-class object sharing ``mod.__dict__``, or ``mod`` unchanged.

    Non-modules and ``torch.fx.GraphModule`` instances are returned as-is
    (Dynamo-generated GM's shouldn't contain user-created GM's).
    """
    if not isinstance(mod, torch.nn.Module) or isinstance(mod, torch.fx.GraphModule):
        return mod
    module_cls = mod.__class__
    # __new__ skips __init__; the instance dict is then shared, not copied.
    proxy = module_cls.__new__(module_cls)
    proxy.__dict__ = mod.__dict__
    return proxy
class GmWrapper(torch.nn.Module):
    """Module that calls ``gm`` after re-structuring flat positional arguments.

    ``forward(*args)`` collects its positional arguments into a list, passes
    that list to ``unflatten_fn``, and calls ``gm`` with the unpacked result.
    """

    def __init__(
        self, gm: torch.fx.GraphModule, unflatten_fn: Callable[[list[Any]], Any]
    ) -> None:
        super().__init__()
        # The wrapped graph module.
        self.gm = gm
        # Maps the flat argument list back to the argument layout gm expects.
        self.unflatten_fn = unflatten_fn

    def forward(self, *args: Any) -> Any:
        # `args` is rebound (not copied into a new name); presumably this keeps
        # the incoming tuple from staying alive during the graph call, which
        # matters for argument stealing — NOTE(review): confirm before changing.
        # pyrefly: ignore [annotation-mismatch, redefinition]
        args: list[Any] = list(args)
        return self.gm(*self.unflatten_fn(args))
def flatten_graph_inputs(
    gm: torch.fx.GraphModule, inputs: Any, compile_gm: Callable[[Any, Any], Any]
) -> Callable[..., Any]:
    """
    Mutate inputs so that they are flat and wrap gm such that it
    accepts those inputs. This is needed for graphs that take
    bumpy (nested, non-flat) inputs.

    ``gm`` is wrapped in a ``GmWrapper`` whose ``unflatten_fn`` rebuilds the
    nested structure from a flat list, and ``compile_gm`` is called with that
    wrapper and the flattened ``inputs``. The returned ``wrapper`` flattens each
    call's positional arguments, clears the caller's containers for placeholders
    marked ``steal_arg``, and passes the flat list to the compiled function as a
    single (boxed) argument.

    Args:
        gm: graph module whose placeholders may take nested inputs.
        inputs: example inputs in the structure ``gm`` expects.
        compile_gm: compiler called as ``compile_gm(module, flat_example_inputs)``.

    Returns:
        A callable taking arguments in the original (nested) structure.
    """
    # Indices (in graph node order) of placeholders whose argument is "stolen":
    # the caller's container at that position is cleared in `wrapper`. This
    # treats the node index as the input index, which assumes placeholders come
    # first in the graph — NOTE(review): confirm.
    inputs_idx_to_clear = [
        i
        for i, node in enumerate(gm.graph.nodes)
        if node.op == "placeholder" and node.meta.get("steal_arg", False)
    ]

    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
        # fast path, avoid pytree overhead
        # compiled autograd inputs are always a list of tensors, maybe followed by symints
        assert inputs_idx_to_clear == [0]
        assert isinstance(inputs[0], list)
        boxed_inputs_count = len(inputs[0])

        def flatten_fn(args: Any) -> Any:
            # Splice the boxed list in args[0] with the trailing positional args.
            return args[0] + list(args[1:])

        def unflatten_fn(flat_args: Any) -> Any:
            # Inverse of flatten_fn: re-box the first `boxed_inputs_count` items.
            return (flat_args[:boxed_inputs_count], *flat_args[boxed_inputs_count:])

        compiled_fn = compile_gm(GmWrapper(gm, unflatten_fn), flatten_fn(inputs))
    else:
        # slow path, don't know inputs structure
        flat_inputs, spec = pytree.tree_flatten(inputs)
        unflatten_fn = functools.partial(pytree.tree_unflatten, treespec=spec)
        compiled_fn = compile_gm(GmWrapper(gm, unflatten_fn), flat_inputs)
        # note this doesn't check the spec, assuming it is the same
        flatten_fn = pytree.arg_tree_leaves

    def wrapper(*args: Any) -> Any:
        flat_args = flatten_fn(args)

        # flat_args is a new list, so we need to clear references from the old list
        for i in inputs_idx_to_clear:
            args[i].clear()

        # this call is boxed to avoid increasing refcount until we reach aot_module_simplified forward
        return compiled_fn(flat_args)

    return wrapper
def get_locals_to_steal(maybe_gm: Any) -> list[Any]:
    """Return ``maybe_gm.meta["locals_to_steal"]``, or ``[]`` when unavailable.

    Anything that is not a GraphModule, or that lacks a ``meta`` attribute,
    yields ``[]``; so does a GraphModule whose meta has no such key.
    """
    if isinstance(maybe_gm, torch.fx.GraphModule) and hasattr(maybe_gm, "meta"):
        return maybe_gm.meta.get("locals_to_steal", [])
    return []


def set_locals_to_steal(gm: torch.fx.GraphModule, locals_to_steal: list[Any]) -> None:
    """Store ``locals_to_steal`` under ``gm.meta["locals_to_steal"]``."""
    gm_meta = gm.meta
    gm_meta["locals_to_steal"] = locals_to_steal
class Lit:
    """Wrapper whose ``repr()`` is the wrapped string itself, without quotes."""

    def __init__(self, s: str) -> None:
        # The literal text to emit from __repr__.
        self.s = s

    def __repr__(self) -> str:
        text = self.s
        return text
# Messages already emitted through warn_once. It is never cleared; keeping
# every string is fine, since warnings.warn keeps a similar registry.
warn_once_cache: set[str] = set()


def warn_once(msg: str, stacklevel: int = 1) -> None:
    """Emit ``msg`` via ``warnings.warn`` only the first time it is seen.

    Dynamo causes all warnings.warn (in user code and in Dynamo code) to print
    every time (https://github.com/pytorch/pytorch/issues/128427); this helper
    works around that by remembering messages already warned on.
    """
    if msg not in warn_once_cache:
        warn_once_cache.add(msg)
        # +1 so the warning is attributed relative to warn_once's caller.
        warnings.warn(msg, stacklevel=stacklevel + 1)
def strip_color_from_string(text: str) -> str:
    """Return ``text`` with ANSI escape sequences (such as color codes) removed."""
    # ESC, one byte in '@'..'_', parameter bytes '0'..'?', intermediate bytes
    # ' '..'/', then a final byte '@'..'~'.
    return re.sub(r"\x1B[@-_][0-?]*[ -/]*[@-~]", "", text)
- @contextlib.contextmanager
- def _disable_saved_tensors_hooks_during_tracing() -> Generator[None, None, None]:
- # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
- try:
- prior = torch._C._autograd._saved_tensors_hooks_set_tracing(True)
- yield
- finally:
- torch._C._autograd._saved_tensors_hooks_set_tracing(prior)
- def is_parameter_freezing() -> bool:
- return torch._inductor.config.freezing and not torch.is_grad_enabled()
def get_torch_function_mode_stack() -> list[Any]:
    """Return a list snapshot of the torch-function mode stack, in index order."""
    depth = _len_torch_function_stack()
    return list(map(get_torch_function_mode_stack_at, range(depth)))


def get_torch_function_mode_stack_at(ind: int) -> Any:
    """Return the torch-function mode stored at stack index ``ind``."""
    assert 0 <= ind < _len_torch_function_stack()
    return torch._C._get_function_stack_at(ind)


def set_torch_function_mode_stack(stack: list[Any]) -> None:
    """Replace the torch-function mode stack with ``stack``, pushing in list order."""
    clear_torch_function_mode_stack()
    for mode in stack:
        _push_on_torch_function_stack(mode)


def clear_torch_function_mode_stack() -> None:
    """Pop every mode off the torch-function mode stack."""
    for _ in range(_len_torch_function_stack()):
        _pop_torch_function_stack()
def get_current_stream(device: torch.device) -> torch.Stream:
    """Return the current stream for ``device`` via ``torch.accelerator``."""
    accel = torch.accelerator
    return accel.current_stream(device)
# call from C dynamo in order to inspect values in pdb
def _breakpoint_for_c_dynamo(*args: Any) -> None:
    """Stop in the debugger via the builtin ``breakpoint()``.

    The positional ``args`` are accepted but not used by the code; presumably
    they are passed so they can be inspected from this frame in pdb — confirm
    against the C-side caller. Keep ``breakpoint()`` directly in this function
    so the debugger stops in the frame that holds ``args``.
    """
    breakpoint()
def verify_guard_fn_signature(value: Any) -> None:
    """Validate the ``__metadata_guard__`` of a tensor-subclass instance ``value``.

    The guard, looked up on ``value``, must accept exactly two parameters and be
    bound to ``value.__class__`` (i.e. be a classmethod).

    Raises:
        InternalTorchDynamoError: if either requirement is violated.
    """
    fn = value.__metadata_guard__
    sig = inspect.signature(fn)
    if len(sig.parameters) != 2:
        from .exc import InternalTorchDynamoError

        raise InternalTorchDynamoError(
            "Tensor subclass method __metadata_guard__ must take exactly two subclass metadata arguments"
        )
    # A staticmethod (or any plain function) has no `__self__`; report it as
    # "not a classmethod" instead of leaking an AttributeError.
    if getattr(fn, "__self__", None) != value.__class__:
        from .exc import InternalTorchDynamoError

        raise InternalTorchDynamoError(
            "Tensor subclass method __metadata_guard__ must be a classmethod"
        )
def does_not_override_dict_iter_methods(user_cls: Any) -> bool:
    """True if ``user_cls`` keeps the dict/OrderedDict ``items``, ``values``, ``keys`` and ``__iter__``.

    Each method is compared against both the ``dict`` and the ``OrderedDict``
    implementation; the checks stop at the first overridden method.
    """
    return all(
        getattr(user_cls, method) in (getattr(dict, method), getattr(OrderedDict, method))
        for method in ("items", "values", "keys", "__iter__")
    )
- # Helper functions below are to prevent TorchDynamo to prevent tracing of
- # __torch_function__ calls triggered on tensor properties in the pre graph
- # bytecode.
- @torch._disable_dynamo
- def call_size(x: Any, i: int) -> int:
- return x.size(i)
- @torch._disable_dynamo
- def call_stride(x: Any, i: int) -> int:
- return x.stride(i)
- @torch._disable_dynamo
- def call_storage_offset(x: Any) -> int:
- return x.storage_offset()
- # Helper function to extract relevant parts of a tensor's __dict__ to store in node meta.
- # To avoid ref cycles, it's important that no tensors are present here, so leave those out.
- def _extract_tensor_dict(t: torch.Tensor) -> dict[str, Any]:
- KEYS_TO_COPY = [
- "_dynamo_static_input_type",
- "tag",
- ]
- tensor_dict = {
- key: copy.copy(t.__dict__[key]) for key in KEYS_TO_COPY if key in t.__dict__
- }
- return tensor_dict
def build_stream(args: tuple[Any], kwargs: dict[Any, Any]) -> torch.Stream:
    """Construct ``torch._C.Stream(*args, **kwargs)``."""
    stream_cls = torch._C.Stream
    return stream_cls(*args, **kwargs)


def build_event(args: tuple[Any], kwargs: dict[Any, Any]) -> torch.Event:
    """Construct ``torch._C.Event(*args, **kwargs)``."""
    event_cls = torch._C.Event
    return event_cls(*args, **kwargs)
class CompileTimeInstructionCounter:
    """Class-level, nesting-aware counter of instructions executed while compiling.

    ``start``/``end`` may nest; only the outermost pair calls into
    ``_instruction_counter``, and each completed outermost region adds its
    measured count to ``_counter``. State lives on the class, so it is shared
    by all callers.
    """

    _counter: int = 0  # accumulated count over completed outermost regions
    _id: int = -1  # handle from _instruction_counter.start(); -1 when idle
    _depth = 0  # current nesting depth of start()/end()

    @classmethod
    def start(cls) -> None:
        """Enter a region; only the outermost entry starts the hardware counter."""
        cls._depth = cls._depth + 1
        if cls._depth == 1:
            cls._id = _instruction_counter.start()

    @classmethod
    def end(cls) -> None:
        """Leave a region; the outermost exit accumulates the measured count."""
        cls._depth = cls._depth - 1
        if cls._depth == 0:
            cls._counter += _instruction_counter.end(cls._id)
            cls._id = -1

    @classmethod
    def clear(cls) -> None:
        """Reset the accumulated count to zero."""
        cls._counter = 0

    @classmethod
    def value(cls) -> int:
        """Return the accumulated count."""
        return cls._counter

    @classmethod
    @contextmanager
    def record(cls) -> Generator[None, None, None]:
        """Count instructions for the block when recording is enabled in config."""
        # Read the flag exactly once. Previously it was re-read in `finally`, so
        # toggling it inside the block could call end() without start() (or
        # skip end()), leaving _depth unbalanced.
        enabled = config.record_compile_time_instruction_count
        try:
            if enabled:
                cls.start()
            yield
        finally:
            if enabled:
                cls.end()
class CompileCounterInt(int):
    """``int`` subclass whose ``+`` with another int keeps the CompileCounterInt type."""

    def __add__(self, other: Any) -> CompileCounterInt:
        result = super().__add__(other)
        # int.__add__ returns NotImplemented for non-int operands (e.g. float).
        # Propagate it so Python falls back to other.__radd__, instead of
        # raising TypeError from CompileCounterInt(NotImplemented).
        if result is NotImplemented:
            return NotImplemented
        return CompileCounterInt(result)
def set_feature_use(feature: str, usage: bool) -> None:
    """
    Record in the active metrics context whether ``feature`` is being used.

    Generally a feature is a JK. Nothing is recorded when no metrics context is
    in progress (e.g. in some tests).
    """
    if not get_metrics_context().in_progress():
        return
    get_metrics_context().set_key_value("feature_usage", feature, usage)
- _ddp_optimization_mode: tuple[str, ...] = (
- "ddp_optimizer",
- "python_reducer", # experimental mode
- "python_reducer_without_compiled_forward",
- "no_optimization",
- )
- def get_optimize_ddp_mode() -> str:
- optimize_ddp = config.optimize_ddp
- if isinstance(optimize_ddp, bool):
- mode = "ddp_optimizer" if optimize_ddp else "no_optimization"
- elif isinstance(optimize_ddp, str):
- mode = optimize_ddp
- else:
- raise ValueError(
- f"Invalid dynamo config optimize_ddp type {type(optimize_ddp)=}"
- )
- assert mode in _ddp_optimization_mode, (
- f"Invalid dynamo config optimize_ddp value {mode=}"
- )
- return mode
@contextmanager
def maybe_disable_inference_mode() -> Generator[None, None, None]:
    """
    Turn torch.inference_mode off during compilation (it stays on at runtime).

    This lets the compile stack assume inference_mode is always off. Because
    inference_mode amounts to no_grad plus some optimizations (version counts
    etc.) that are irrelevant to torch.compile, no_grad is enabled in its place.
    Only applies when ``config.fake_tensor_disable_inference_mode`` is set and
    inference mode is currently enabled; otherwise this is a no-op.
    """
    if not (
        config.fake_tensor_disable_inference_mode and torch.is_inference_mode_enabled()
    ):
        yield
        return

    with torch.inference_mode(False), torch.no_grad():
        yield
@contextmanager
def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
    """
    Stop tracking inference_mode during fake tensor propagation.

    Inside this context, a real tensor converted to a fake tensor loses its
    inference-ness. Only active when ``config.fake_tensor_disable_inference_mode``
    is set; otherwise this is a no-op.
    """
    if not config.fake_tensor_disable_inference_mode:
        yield
        return

    with torch._subclasses.meta_utils.disable_inference_mode_for_fake_prop():
        yield
def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool:
    """True for ``None`` or a node whose meta holds "example_value" or "val"."""
    if node is None:
        return True
    meta = node.meta
    return any(key in meta for key in ("example_value", "val"))
- # If True, enforce fullgraph=True - raise errors on graph break
- _error_on_graph_break = False
- def _get_error_on_graph_break() -> bool:
- return _error_on_graph_break
- def _set_error_on_graph_break(value: bool) -> None:
- global _error_on_graph_break
- _error_on_graph_break = value
- @torch._disable_dynamo
- def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
- cm: AbstractContextManager[None] = (
- torch._C._profiler._RecordFunctionFast("Pregraph bytecode")
- if torch.autograd.profiler._is_profiler_enabled
- else contextlib.nullcontext()
- )
- cm.__enter__()
- return cm
- @torch._disable_dynamo
- def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
- cm.__exit__(None, None, None)
- # Returns a set of code objects present traced in the current TracingContext, or None
- # if there is no current TracingContext.
- def get_traced_code() -> Optional[list[CodeType]]:
- from torch._guards import TracingContext
- return TracingContext.get_traced_code()
def raise_on_overridden_hash(obj: Any, vt: VariableTracker) -> None:
    """Report a graph break via ``unimplemented`` if ``type(obj)`` defines ``__hash__``.

    Only the class's own ``__dict__`` is consulted (base classes are not), and
    an entry that is falsy (e.g. ``__hash__ = None``) does not count.
    """
    from . import graph_break_hints
    from .exc import unimplemented

    if not type(obj).__dict__.get("__hash__", False):
        return

    unimplemented(
        gb_type="User-defined object with overridden __hash__",
        context=f"hashing object of type={type(obj)} and variable tracker {vt}",
        explanation=f"Found a user-defined object {vt} with overridden __hash__ when attempting to hash it",
        hints=[
            "Dynamo does not support hashing user-defined objects with overridden __hash__",
            *graph_break_hints.SUPPORTABLE,
        ],
    )
|