| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956 |
- # mypy: allow-untyped-defs
- from __future__ import annotations
- """
- This file does three things:
- - Contains the definition of SymNode
- - Installs all the magic methods into SymBool, SymInt, SymFloat at import time
- - Does not depend on sympy at import time
- As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
- to avoid having to load SymPy at import time, as doing so is *very* slow.
- """
- import builtins
- import functools
- import inspect
- import itertools
- import logging
- import math
- import operator
- import sys
- from functools import lru_cache, update_wrapper
- from typing import Optional, TYPE_CHECKING, Union
- import torch
- import torch._logging.structured as structured
- # NB: The sym_* functions are used via getattr() and must be imported here.
- from torch import ( # noqa: F401
- sym_float,
- sym_ite,
- sym_max,
- sym_min,
- sym_not,
- SymBool,
- SymFloat,
- SymInt,
- )
- from torch._logging import dtrace_structured
- if TYPE_CHECKING:
- from torch.fx.experimental.symbolic_shapes import ShapeEnv
# Standard module logger, plus a dedicated artifact logger registered under
# the "sym_node" artifact name.
log = logging.getLogger(__name__)
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")

# Sentinel value to indicate "don't compute hint" vs actual None.
# When passed as ``hint`` to SymNode, it means we already know the hint is
# unavailable and should not waste time calling compute_hint().
# It is compared by identity (``hint is _NO_HINT``).
_NO_HINT: object = object()

# Type alias for hint values.
# NOTE(review): the union below lists only bool/float/int/None; the
# ``_NO_HINT`` sentinel is typed as plain ``object`` and is not part of it.
HintType = bool | float | int | None

# Public names of this module.  More names are appended further down when the
# ``sym_<math op>`` methods are installed.
# NOTE(review): "method_to_operator" and "magic_methods" are not defined in the
# portion of the file visible here -- presumably defined later; confirm.
__all__ = ["SymNode", "method_to_operator", "magic_methods", "DynamicInt"]
- from torch.types import py_sym_types as SymTypes
- def _to_symtype(t):
- if t is bool:
- return SymBool
- if t is int:
- return SymInt
- if t is float:
- return SymFloat
- return t
- # TODO: An incomplete list
- # 1. Set variables to be equal when we do equality
- # 2. Specialize on 0/1 when we do subtraction
class SymNode:
    """
    This is a type erased SymInt/SymFloat which we use to do actual operations.
    End users don't touch this.  Magic methods are NOT defined on this object.

    A SymNode bundles a symbolic expression (``_expr``), the ShapeEnv it lives
    in, the Python scalar type it stands for (``pytype``), an optional hint, an
    optional constant and an optional FX node.  Most arithmetic/comparison
    methods below simply forward to a same-named ``_``-prefixed attribute
    (``add`` -> ``_add`` and so on); those attributes are not defined in this
    class body (hence the ``attr-defined`` ignores).
    """

    # Note [optimized_summation]: indicates that SymNode is an Add expression of the form
    # a + b + c + d... etc where all terms are unique symbols. This allows us to do some optimizations
    # for common patterns see _optimized_add.

    # The unfortunate reason we have this here is because sympy sets __slots__ = () for add expression,
    # so we cannot add the attribute directly to the sympy expression. Furthermore, we cannot use it as
    # a weak dictionary key either! So instead, we attach the attribute here to the SymNode.
    _optimized_summation: bool = False

    def __init__(
        self,
        expr,
        shape_env,
        pytype,
        hint: Optional[Union[int, float, bool]],
        constant=None,
        fx_node=None,
        optimized_summation=False,
    ):
        """Create a SymNode.

        Args:
            expr: the symbolic expression; stored as ``_expr``.
            shape_env: the owning ShapeEnv (may be falsy; it is truth-tested
                before translation-validation attributes are read).
            pytype: the Python scalar type this node stands for.
            hint: a concrete value for this run.  ``None`` means "try to
                compute one"; the ``_NO_HINT`` sentinel means "known to be
                unavailable, do not compute".
            constant: literal value if the node is a known constant.
            fx_node: FX node, only retained when translation validation is
                enabled on ``shape_env``.
            optimized_summation: see Note [optimized_summation].

        Raises:
            AssertionError: if ``hint`` has a type other than ``pytype`` or its
                Sym* counterpart, or (translation validation only) if it
                disagrees with the recomputed hint.
        """
        self._expr = expr
        self.shape_env = shape_env
        self.pytype = pytype
        self._optimized_summation = optimized_summation

        # What's the difference between hint and constant?
        #
        # - A constant is known to be invariant across invocations of the model;
        #   it will always be this value.  We only really know this when we
        #   encounter an honest-to-goodness literal (when wrapping it into
        #   a SymNode, we set constant.)  Most of the time, constant is None
        #
        # - A hint is a *particular* value from the particular run we are
        #   tracing, but it may vary the next time around.  It's useful to
        #   keep this around, as if we need a concrete value from a SymNode,
        #   we will return the hint and guard on the expression that produced
        #   it giving the same hint next time around.  The hint is not
        #   guaranteed to be set either: if you have an unbacked SymNode,
        #   there won't be any hint; it was the result of some tensor-dependent
        #   computation, but we don't know what it actually is because we
        #   haven't actually run the tensor computation.
        #
        # If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
        # in hopes that we've learned enough about the unbacked symints to
        # discharge the hint; otherwise, you're likely to just error out.
        #
        # (A previous version of this system had some optimizations to only
        # recompute when it was possible we had learned enough about the
        # unbacked symint that a hint was now possible, but as we added more
        # potential refinements to unbacked symints this got harder to keep
        # in sync, so we've deleted it for now.)

        # NB: this closure reads self.expr / self.shape_env / self.pytype, so
        # it must only be called after the assignments above.
        def compute_hint():
            from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols

            # This occasionally gets exercised by, e.g.,
            # convert_shape_to_symint.  It's just a nicety so you don't HAVE
            # to have a correct hint on hand when making a SymNode.
            # Don't attempt to compute for unbacked, this can be quite
            # expensive.
            if has_free_unbacked_symbols(self.expr):
                return None
            hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
            if hint is not None:
                # Coerce to pytype unless the result is already a Sym* object.
                hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
            return hint

        if hint is _NO_HINT:
            # Caller explicitly indicates hint is unavailable, don't compute
            hint = None
        elif hint is not None:
            # Exact type match on purpose (``type(...) is``), not isinstance.
            if not (type(hint) is pytype or type(hint) is _to_symtype(pytype)):
                raise AssertionError(
                    "Cannot create SymNode of type "
                    f"{pytype} with incompatible hint of type {type(hint)}"
                )
            if self.shape_env and self.shape_env._translation_validation_enabled:
                # This is technically not TV, but this assert is expensive so
                # let's only do it when we're already doing expensive things
                computed_hint = compute_hint()
                if hint != computed_hint:
                    raise AssertionError(f"{hint} != {computed_hint} (for {self.expr})")
        else:
            hint = compute_hint()
        self._hint = hint
        self.constant: Optional[Union[int, float, bool]] = constant

        # Record the FX node of the current node if we are doing translation
        # validation. They will be used for building the input assertions for
        # the translation validation problem.
        tx_validation_en = (
            self.shape_env and self.shape_env._translation_validation_enabled
        )
        # NB: because of the ``and`` chain, fx_node is whatever falsy value
        # short-circuited (e.g. None/False) when validation is off.
        self.fx_node = tx_validation_en and fx_node

    def with_shape_env(self, shape_env: ShapeEnv) -> SymNode:
        """Return a new SymNode with the same fields but bound to ``shape_env``.

        NOTE(review): ``optimized_summation`` is not forwarded, so the copy
        gets the constructor default (False).
        """
        return SymNode(
            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
        )

    def _value_eq(self, other: SymNode) -> bool:
        """Field-wise equality on expr, pytype, hint, constant and fx_node."""
        # Purposely don't include the shape_env in the eq.
        return (
            self._expr == other._expr
            and self.pytype == other.pytype
            and self._hint == other._hint
            and self.constant == other.constant
            and self.fx_node == other.fx_node
        )

    def _value_hash(self) -> int:
        """Hash over the same fields that ``_value_eq`` compares."""
        # Purposely don't include the shape_env in the hash.
        return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))

    @property
    def expr(self):
        """The stored expression after passing it through ``shape_env.replace``."""
        return self.shape_env.replace(self._expr)

    @property
    def hint(self):
        """The stored hint, or None if there is none."""
        return self._hint

    def has_hint(self):
        """Return True when a hint is stored."""
        return self._hint is not None

    def require_hint(self, fallback=None):
        """Return the hint, or derive a value when none is stored.

        With no stored hint and a non-None ``fallback``, every unbacked free
        symbol is substituted by ``fallback`` and every other free symbol by
        its ``shape_env.backed_var_to_val`` entry, and the result is converted
        with ``int``.  With no hint and no fallback, defers to
        ``shape_env.size_hint`` (expected to raise, per the comment below).
        """
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        if self._hint is None:
            if fallback is not None:
                # Say we have some expr like 2*u0 + s0
                # The hint will be None, since the expr contains at least 1 unbacked.
                # We will:
                # - replace every backed free symbol with its corresponding hint
                # - replace every unbacked free symbol with the fallback
                # - regenerate the expression with those symbol replacements
                # Note: this is not really complete either, since right now
                # this logic does not take into account any value ranges
                # for the unbacked symints, we may need to beef it up at some point.
                unbacked_symbols = free_unbacked_symbols(self.expr)
                replacements = {
                    s: fallback
                    if s in unbacked_symbols
                    else self.shape_env.backed_var_to_val[s]
                    for s in self.expr.free_symbols
                }
                return int(self.expr.xreplace(replacements))
            # NB: we expect this to raise
            return self.shape_env.size_hint(self.expr)
        return self._hint

    def maybe_as_int(self):
        """Return ``int(self.expr)`` if the expression is a number, else None."""
        if self.expr.is_number:
            return int(self.expr)
        else:
            return None

    # NB: This does conversions, not sure if this is good or not
    def maybe_as_float(self):
        """Return ``float(self.expr)`` if the expression is a ``sympy.Float``, else None."""
        import sympy

        if isinstance(self.expr, sympy.Float):
            return float(self.expr)
        else:
            return None

    def maybe_as_bool(self):
        """Return True/False if the expression is ``sympy.true``/``sympy.false``, else None."""
        import sympy

        if self.expr is sympy.true:
            return True
        elif self.expr is sympy.false:
            return False
        else:
            return None

    def is_int(self):
        return self.pytype is int

    def is_float(self):
        return self.pytype is float

    def is_bool(self):
        return self.pytype is bool

    def is_nested_int(self):
        """True when the hint is a SymInt whose node reports ``is_nested_int()``."""
        # Unbacked SymInts cannot be nested int today
        return (
            self._hint is not None
            and isinstance(self._hint, SymInt)
            and self._hint.node.is_nested_int()
        )

    def wrap_int(self, num):
        """Wrap a Python int as a constant SymNode sharing this node's shape_env.

        Raises AssertionError unless ``type(num)`` is exactly ``int``.
        """
        if type(num) is not int:
            raise AssertionError(f"Expected int, got {type(num)}")
        import sympy

        return SymNode(
            sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
        )

    def wrap_float(self, num):
        """Wrap a Python float as a constant SymNode sharing this node's shape_env.

        Raises AssertionError unless ``type(num)`` is exactly ``float``.
        """
        if type(num) is not float:
            raise AssertionError(f"Expected float, got {type(num)}")
        import sympy

        return SymNode(
            sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
        )

    def wrap_bool(self, num):
        """Wrap a Python bool as a constant SymNode sharing this node's shape_env.

        Raises AssertionError unless ``type(num)`` is exactly ``bool``.
        """
        if type(num) is not bool:
            raise AssertionError(f"Expected bool, got {type(num)}")
        import sympy

        return SymNode(
            sympy.true if num else sympy.false,
            self.shape_env,
            bool,
            num,
            constant=num,
            fx_node=num,
        )

    def clone(self):
        """Return ``self``; no copy is made."""
        return self

    def str(self):
        """String form of ``self.expr`` (the expression after replacement)."""
        return f"{self.expr}"

    def __str__(self):
        return self.str()

    def __repr__(self):
        # Uses the raw ``_expr`` (not the ``expr`` property); the optional
        # fields are appended only when they are not None.
        rep = [
            f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
        ]
        if self._hint is not None:
            rep.append(f"hint={self._hint}")
        if self.constant is not None:
            rep.append(f"constant={self.constant}")
        if self.fx_node is not None:
            rep.append(f"fx_node={self.fx_node}")
        return ", ".join(rep) + ")"

    def _graph_repr(self) -> builtins.str:
        # Representation used by GraphModule to create a pythonic version of a graph
        return self.str()

    # These methods call the metaprogrammed methods, they're hand written
    # here so we get good stack traces
    def abs(self) -> SymNode:
        return self._abs()  # type: ignore[attr-defined]

    def pos(self) -> SymNode:
        return self._pos()  # type: ignore[attr-defined]

    def round(self, ndigits=None) -> SymNode:
        return self._round(ndigits)  # type: ignore[attr-defined]

    def trunc(self) -> SymNode:
        return self._trunc()  # type: ignore[attr-defined]

    def add(self, other) -> SymNode:
        return self._add(other)  # type: ignore[attr-defined]

    def sub(self, other) -> SymNode:
        return self._sub(other)  # type: ignore[attr-defined]

    def mul(self, other) -> SymNode:
        return self._mul(other)  # type: ignore[attr-defined]

    def mod(self, other) -> SymNode:
        return self._mod(other)  # type: ignore[attr-defined]

    def float_pow(self, other) -> SymNode:
        return self._float_pow(other)  # type: ignore[attr-defined]

    def pow_by_natural(self, other) -> SymNode:
        return self._pow_by_natural(other)  # type: ignore[attr-defined]

    def and_(self, other) -> SymNode:
        return self._and_(other)  # type: ignore[attr-defined]

    def or_(self, other) -> SymNode:
        return self._or_(other)  # type: ignore[attr-defined]

    def float_truediv(self, other) -> SymNode:
        return self._float_truediv(other)  # type: ignore[attr-defined]

    def int_truediv(self, other) -> SymNode:
        return self._int_truediv(other)  # type: ignore[attr-defined]

    def int_floordiv(self, other) -> SymNode:
        return self._int_floordiv(other)  # type: ignore[attr-defined]

    def lshift(self, other) -> SymNode:
        return self._lshift(other)  # type: ignore[attr-defined]

    def rshift(self, other) -> SymNode:
        return self._rshift(other)  # type: ignore[attr-defined]

    def sym_not(self) -> SymNode:  # noqa: F811
        return self._sym_not()  # type: ignore[attr-defined]

    def eq(self, other) -> SymNode:
        return self._eq(other)  # type: ignore[attr-defined]

    def ne(self, other) -> SymNode:
        return self._ne(other)  # type: ignore[attr-defined]

    def gt(self, other) -> SymNode:
        return self._gt(other)  # type: ignore[attr-defined]

    def lt(self, other) -> SymNode:
        return self._lt(other)  # type: ignore[attr-defined]

    def le(self, other) -> SymNode:
        return self._le(other)  # type: ignore[attr-defined]

    def ge(self, other) -> SymNode:
        return self._ge(other)  # type: ignore[attr-defined]

    def floor(self) -> SymNode:
        return self._floor()  # type: ignore[attr-defined]

    def is_integer(self) -> SymNode:
        return self._is_integer()  # type: ignore[attr-defined]

    def sym_float(self) -> SymNode:  # noqa: F811
        return self._sym_float()  # type: ignore[attr-defined]

    def sym_int(self) -> SymNode:
        return self._sym_int()  # type: ignore[attr-defined]

    def ceil(self) -> SymNode:
        return self._ceil()  # type: ignore[attr-defined]

    def neg(self) -> SymNode:
        return self._neg()  # type: ignore[attr-defined]

    def sym_min(self, other) -> SymNode:  # noqa: F811
        return self._sym_min(other)  # type: ignore[attr-defined]

    def sym_max(self, other) -> SymNode:  # noqa: F811
        return self._sym_max(other)  # type: ignore[attr-defined]

    def sym_ite(self, then_val, else_val) -> SymNode:
        return self._sym_ite(then_val, else_val)  # type: ignore[attr-defined]

    def is_contiguous(self, sizes, strides) -> SymNode:
        return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_contiguous_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_contiguous_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_2d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_strides_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_3d(self, sizes, strides) -> SymNode:
        return self._is_channels_last_strides_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode:
        return self._is_non_overlapping_and_dense_indicator(sizes, strides)  # type: ignore[attr-defined]

    # Make C++ happy
    def sym_or(self, other):
        return self.or_(other)

    def sym_and(self, other):
        return self.and_(other)

    # Integer bitwise ops
    def bitwise_and(self, other):
        return self._bitwise_and(other)  # type: ignore[attr-defined]

    def bitwise_or(self, other):
        return self._bitwise_or(other)  # type: ignore[attr-defined]

    def bitwise_xor(self, other):
        return self._bitwise_xor(other)  # type: ignore[attr-defined]

    # There is no int_truediv available from C++
    def truediv(self, other):
        return self.float_truediv(other)

    def floordiv(self, other) -> SymNode:
        return self.int_floordiv(other)

    # We didn't bind integer pow in C++
    def pow(self, other):
        return self.float_pow(other)

    def is_non_overlapping_and_dense(self, sizes, strides):
        # Compares the indicator node against a node wrapping the literal 1.
        return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(
            to_node(self, 1)
        )  # type: ignore[attr-defined]

    def int_(self):
        return self.guard_int("", 0)  # NB: uses Python backtrace

    # This one is currently done by hand, but if we add other variadic
    # functions consider factoring it out to be metaprogrammed too.  Note that
    # some load bearing logic is directly in torch.sym_sum

    def sym_sum(self, args) -> SymNode:
        """Sum the SymNodes in ``args`` into a single int-typed SymNode.

        When a proxy mode is active the call is routed through
        ``handle_sym_dispatch``.  Otherwise the expressions are combined with
        ``sympy.Add``; the result hint is the sum of the argument hints, or
        None as soon as any argument lacks a hint.
        """
        import sympy

        # Inner impl
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        if get_proxy_mode():
            return to_node(
                self,
                handle_sym_dispatch(
                    torch.sym_sum,
                    (tuple(wrap_node(a) for a in args),),
                    {},
                ),
            )
        exprs = [a.expr for a in args]
        out = sympy.Add(*exprs)

        size_hints = []
        out_hint = None
        for a in args:
            if a.hint is None:
                break
            size_hints.append(a.hint)
        else:
            # for/else: only reached when no argument was missing a hint.
            out_hint = sum(size_hints)

        fx_node, _ = self.shape_env._create_fx_call_function(
            torch.sym_sum, (tuple(a.fx_node for a in args),)
        )

        # NB: Only for integers!
        return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node)

    def evaluate(self, size_oblivious=False):
        """Delegate to ``shape_env.evaluate_sym_node(self, size_oblivious)``."""
        return self.shape_env.evaluate_sym_node(self, size_oblivious)

    # You can manually trigger a guard with this function
    def guard_int(self, file, line):
        """Evaluate this node and convert the result with ``int``.

        A failed conversion is logged as a warning and re-raised.  ``file`` and
        ``line`` are currently unused (see TODO).
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate()
        try:
            return int(r)
        except Exception:
            log.warning("Failed to convert to int: %s", r)
            raise

    def guard_float(self, file, line):
        """Evaluate this node and convert the result with ``float``.

        A failed conversion is logged as a warning and re-raised.  ``file`` and
        ``line`` are currently unused (see TODO).
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate()
        try:
            return float(r)
        except Exception:
            log.warning("Failed to convert to float: %s", r)
            raise

    def guard_bool(self, file, line):
        """Evaluate this node and convert the result with ``bool``.

        A failed conversion is logged as a warning and re-raised.  ``file`` and
        ``line`` are currently unused (see TODO).
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate()
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def expect_true(self, file, line):
        """Guard on this node when possible, otherwise defer a runtime assert.

        A regular ``guard_bool`` is used only when the node has a hint, its
        expression has no free unbacked symbols, and the shape_env does not
        prefer deferred runtime asserts.  In every other case the expression
        is handed to ``shape_env.guard_or_defer_runtime_assert`` tagged with
        ``"{file}:{line}"``.
        """
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        if (
            self.has_hint()
            and not free_unbacked_symbols(self.expr)
            and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
        ):
            # OK to generate guards
            return self.guard_bool(file, line)
        # Generate a deferred runtime assert (this might actually end up doing
        # a regular guard if we can!)
        # TODO: file/line here is very important, because the assert has been
        # deferred so you can't backtrace easily
        return self.shape_env.guard_or_defer_runtime_assert(
            self.expr, f"{file}:{line}", fx_node=self.fx_node
        )

    def statically_known_true(self, file, line):
        """Delegate to ``symbolic_shapes.statically_known_true`` on a SymBool of self.

        Raises AssertionError if this node is not bool-typed.  ``file`` and
        ``line`` are not used.
        """
        from torch.fx.experimental.symbolic_shapes import statically_known_true

        if not self.is_bool():
            raise AssertionError("Expected bool type")
        return statically_known_true(SymBool(self))

    def guard_size_oblivious(self, file, line):
        """
        Like guard_bool, but if we encounter unbacked symbols, if those symbols
        are size-like, we will treat them as >= 2 for the purposes of the analysis.

        This CHANGES the runtime semantics, but all size-oblivious sites have been
        audited to ensure that the runtime semantics don't change in a material way.
        Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
        an unbacked one size, or a tensor reporting as non-contiguous even if it's
        contiguous if it would have been reported contiguous due to being empty.
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.evaluate(size_oblivious=True)
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def guard_or_false(self, file, line):
        """Delegate to ``symbolic_shapes.guard_or_false`` on a SymBool of self.

        Raises AssertionError if this node is not bool-typed.  ``file`` and
        ``line`` are not used.
        """
        from torch.fx.experimental.symbolic_shapes import guard_or_false

        if not self.is_bool():
            raise AssertionError("Expected bool type")
        return guard_or_false(SymBool(self))

    def guard_or_true(self, file, line):
        """Delegate to ``symbolic_shapes.guard_or_true`` on a SymBool of self.

        Raises AssertionError if this node is not bool-typed.  ``file`` and
        ``line`` are not used.
        """
        from torch.fx.experimental.symbolic_shapes import guard_or_true

        if not self.is_bool():
            raise AssertionError("Expected bool type")
        return guard_or_true(SymBool(self))

    def bool_(self):
        return self.guard_bool("", 0)

    # The three methods below return fixed answers.
    def is_symbolic(self):
        return True

    def nested_int(self):
        return None

    def is_constant(self):
        return False
class _DynamicScalar:
    """Abstract base for dynamic scalar markers; only subclasses may be instantiated."""

    def __new__(cls, *args):
        # Refuse direct instantiation of the base itself.
        if cls is _DynamicScalar:
            raise TypeError("_DynamicScalar is an abstract base class, use DynamicInt.")
        return super().__new__(cls, *args)


class DynamicInt(_DynamicScalar, int):
    """
    User API for marking dynamic integers in `torch.compile`.
    Intended to be compatible with both compile and eager mode.

    Example usage::

        fn = torch.compile(f)
        x = DynamicInt(4)
        fn(x)  # compiles x as a dynamic integer input; returns f(4)
    """

    def __new__(cls, val):
        """Create a DynamicInt from ``val``.

        Raises:
            AssertionError: if ``val`` is not an ``int`` instance.
        """
        if not isinstance(val, int):
            raise AssertionError(f"Expected int, got {type(val)}")
        obj = super().__new__(cls, int(val))
        return obj

    def __repr__(self):
        return f"DynamicInt({self.real})"

    # Without these overrides ``//`` yields a plain ``int``.  The result is
    # re-wrapped only when it actually is an int: ``DynamicInt(7) // 2.0`` must
    # behave like ``7 // 2.0`` (-> 3.0) instead of tripping the int-only check
    # in ``__new__``.
    def __floordiv__(self, other):
        """Return ``self // other``, as a DynamicInt when the result is an int."""
        result = self.real // other
        return DynamicInt(result) if isinstance(result, int) else result

    def __rfloordiv__(self, other):
        """Return ``other // self``, as a DynamicInt when the result is an int."""
        result = other // self.real
        return DynamicInt(result) if isinstance(result, int) else result
# Maps a SymNode method name to the Python-level callable implementing the
# same operation (operator.*, math.*, builtins.round, or a torch.sym_*
# function).  Several names share one callable: "and"/"bitwise_and",
# "or"/"bitwise_or", "float_pow"/"pow_by_natural", and
# "float_truediv"/"int_truediv".  More "sym_<math op>" entries are added
# after this literal.
# TODO: this probably needs the sizes-strides eval functions
METHOD_TO_OPERATOR = {
    "pos": operator.pos,
    "abs": operator.abs,
    "add": operator.add,
    "and": operator.and_,
    "bitwise_and": operator.and_,
    "ceil": math.ceil,
    "eq": operator.eq,
    "floor": math.floor,
    "trunc": math.trunc,
    "int_floordiv": operator.floordiv,
    "ge": operator.ge,
    "gt": operator.gt,
    "is_integer": lambda x: x.is_integer(),
    "le": operator.le,
    "lshift": operator.lshift,
    "lt": operator.lt,
    "mod": operator.mod,
    "mul": operator.mul,
    "ne": operator.ne,
    "neg": operator.neg,
    "or": operator.or_,
    "bitwise_or": operator.or_,
    "bitwise_xor": operator.xor,
    "float_pow": operator.pow,
    "pow_by_natural": operator.pow,
    "round": builtins.round,
    "rshift": operator.rshift,
    "sub": operator.sub,
    "sym_float": sym_float,
    "sym_ite": sym_ite,
    "sym_max": sym_max,
    "sym_min": sym_min,
    "sym_not": sym_not,
    "float_truediv": operator.truediv,
    "int_truediv": operator.truediv,
}
# Names of the single-operand methods.  This set is mutated after creation:
# the math-op installation loop adds one "sym_<op>" name per math op.
unary_magic_methods = {
    "abs",
    "sym_float",
    "sym_int",
    "ceil",
    "floor",
    "neg",
    "sym_not",
    "pos",
    "trunc",
}
- # Adding math ops: sqrt, cos, sin, ...
- def _get_sym_node_fn(name):
- def fn(self):
- return getattr(self, f"_sym_{name}")()
- return fn
# Math functions exposed on SymNode as ``sym_<name>``.
math_op_names = (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
    "log2",
)
# For every math op: install a ``sym_<name>`` forwarding method on SymNode,
# map it to the private ``torch._sym_<name>`` callable, register it as a unary
# magic method, and export it through ``__all__``.
# NB: the loop variables (name, sym_name, priv_sym_name) stay bound at module
# level after the loop finishes.
for name in math_op_names:
    sym_name = f"sym_{name}"
    priv_sym_name = f"_{sym_name}"
    setattr(SymNode, sym_name, _get_sym_node_fn(name))
    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
    unary_magic_methods.add(sym_name)
    __all__.append(sym_name)
# Unary methods that are not magic methods
unary_nonmagic_methods = {
    "is_integer",
}

# Union taken at this point, i.e. after the math-op names were added to
# unary_magic_methods.
unary_methods = unary_magic_methods | unary_nonmagic_methods

# Most methods are only registered on SymInt and SymFloat
# Some methods are only registered on SymBool
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
# Methods that implicitly convert SymBool into SymInt
bool_becomes_int_magic_methods = {"add", "sub", "mul"}
# Methods that are also on SymBool, in addition to on SymInt and SymFloat
also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods

# Methods that are only for float
only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}

# Names whose ``operator`` module counterpart carries a trailing underscore
# (``operator.and_`` / ``operator.or_``).
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}

# remap necessary because an op name can have a bitwise and boolean implementation
bitwise_ops = {"bitwise_and": "and", "bitwise_or": "or", "bitwise_xor": "xor"}

# Set names follow an ``always_<type>`` convention.
# NOTE(review): presumably the result type of each method regardless of
# operand types -- confirm against the code that consumes these sets.
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}

# Every ``sym_<math op>`` is added to the always-float set as well.
for name in math_op_names:
    sym_name = f"sym_{name}"
    always_float_magic_methods.add(sym_name)

always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
always_bool_magic_methods = {
    "eq",
    "ne",
    "gt",
    "lt",
    "le",
    "ge",
    "and",
    "or",
    "sym_not",
    "is_non_overlapping_and_dense",
    "is_integer",
}
- # Methods that have a `__foo__` as well as `__rfoo__`
def _sympy_float_truediv(a, b):
    """Build ``FloatTrueDiv(a, b)``; the torch sympy helpers are imported lazily."""
    import torch.utils._sympy.functions as sympy_functions

    return sympy_functions.FloatTrueDiv(a, b)
def _sympy_int_truediv(a, b):
    """Build ``IntTrueDiv(a, b)``; the torch sympy helpers are imported lazily."""
    import torch.utils._sympy.functions as sympy_functions

    return sympy_functions.IntTrueDiv(a, b)
def _sympy_floordiv(a, b):
    """Build ``FloorDiv(a, b)``; the torch sympy helpers are imported lazily."""
    import torch.utils._sympy.functions as sympy_functions

    return sympy_functions.FloorDiv(a, b)
def _sympy_mod(a, b):
    """Build a modulo expression for ``a % b``.

    ``Mod`` is used when both operands report ``is_nonnegative`` as truthy;
    in every other case (including "unknown") ``PythonMod`` is used.
    """
    import torch.utils._sympy.functions as sympy_functions

    both_nonnegative = a.is_nonnegative and b.is_nonnegative
    mod_cls = sympy_functions.Mod if both_nonnegative else sympy_functions.PythonMod
    return mod_cls(a, b)
def _sympy_pow_by_natural(a, b):
    """Build ``PowByNatural(a, b)``; the torch sympy helpers are imported lazily."""
    import torch.utils._sympy.functions as sympy_functions

    return sympy_functions.PowByNatural(a, b)
def _sympy_float_pow(a, b):
    """Build ``FloatPow(a, b)``; the torch sympy helpers are imported lazily."""
    import torch.utils._sympy.functions as sympy_functions

    return sympy_functions.FloatPow(a, b)
- def _sympy_and(a, b):
- import sympy
- return sympy.And(a, b)
- def _sympy_or(a, b):
- import sympy
- return sympy.Or(a, b)
def _sympy_lshift(a, b):
    """Build ``LShift(a, b)``; the torch sympy helpers are imported lazily."""
    import torch.utils._sympy.functions as sympy_functions

    return sympy_functions.LShift(a, b)
def _sympy_rshift(a, b):
    """Build ``RShift(a, b)``; the torch sympy helpers are imported lazily."""
    import torch.utils._sympy.functions as sympy_functions

    return sympy_functions.RShift(a, b)
- def _binary_search_insert_arg(ordered_args, new_arg):
- """
- If new_arg is found in ordered_args None is returned, else the new
- ordered_args with new_arg inserted
- """
- if len(ordered_args) == 0:
- return [new_arg]
- from sympy.core.basic import _args_sortkey as sort_key, Basic
- # Fast path when new_arg > ordered_args[-1].
- if sort_key(ordered_args[-1]) < sort_key(new_arg):
- return ordered_args + [new_arg]
- # Fast path when new_arg < ordered_args[0].
- if sort_key(ordered_args[0]) > sort_key(new_arg):
- return [new_arg] + ordered_args
- low, high = 0, len(ordered_args) - 1
- while low <= high:
- mid = (low + high) // 2
- compare_result = Basic.compare(ordered_args[mid], new_arg)
- if compare_result == 0:
- return None
- elif compare_result < 0:
- low = mid + 1
- else:
- high = mid - 1
- ordered_args.insert(low, new_arg)
- return ordered_args
- def _optimized_add(
- lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False
- ):
- """
- Custom optimization for Add used to optimize incremental binary summations of certain properties. The idea
- is when we know the expression is a summation of unique symbols all we need to know is the correct order of symbols,
- and no other optimizations are needed. We pass evaluate=false, with the correct order of args and save the following.
- 1. Avoid running other optimizations when the Add is constructed.
- 2. Manually figure out the order of the args for the new expression in log(n) comparisons instead of nLog(n)
- (comparing terms is expensive and shows in the profiles).
- The function returns a tuple of (1) a boolean that indicates whether the output is a summation of unique symbols,
- (2) the result sympy expression.
- """
- import sympy
- from sympy.core.basic import _args_sortkey as sortkey
- def make_optimized(ordered_args):
- if ordered_args is None:
- raise AssertionError("ordered_args is None")
- # Use _from_args directly to bypass _exec_constructor_postprocessors
- # which iterates over all args. This is safe because args are only
- # symbols or constants, which don't register postprocessors.
- # Pass is_commutative=True to avoid fuzzy_and check over all args.
- result = sympy.Add._from_args(ordered_args, is_commutative=True)
- return (True, result)
- from torch.utils._sympy.functions import _is_symbols_binary_summation
- lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs)
- rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs)
- if lhs_is_optimized_summation and rhs_is_optimized_summation:
- # (a0+a1..) + (a2+a3..) => (a0+a1+a2+a3)
- if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]):
- return make_optimized(lhs._args + rhs._args)
- # (a2+a3..) + (a0+a1..) => (a0+a1+a2+a3)
- if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
- return make_optimized(rhs._args + lhs._args)
- # (a1+a3) + (a0+a2) => (a0+a1+a2+a3)
- if len(lhs._args) <= 2 and len(rhs._args) <= 2:
- new_args = list(lhs._args)
- for a in rhs._args:
- new_args = _binary_search_insert_arg(new_args, a)
- if new_args is None:
- break
- # None means an element already exists.
- if new_args is not None:
- return make_optimized(new_args)
- # (a0+a2) + a1 => (a0+a1+a2)
- if lhs_is_optimized_summation and rhs.is_symbol:
- new_args = _binary_search_insert_arg(list(lhs._args), rhs)
- # None means an element already exists.
- if new_args is not None:
- return make_optimized(new_args)
- # a1 + (a0+a2)=> (a0+a1+a2)
- if rhs_is_optimized_summation and lhs.is_symbol:
- new_args = _binary_search_insert_arg(list(rhs._args), lhs)
- # None means an element already exists.
- if new_args is not None:
- return make_optimized(new_args)
- result = sympy.Add(lhs, rhs)
- return (_is_symbols_binary_summation(result), result)
- def _bitwise_and(a, b):
- from torch.utils._sympy.functions import BitwiseFn_bitwise_and
- return BitwiseFn_bitwise_and(a, b)
- def _bitwise_or(a, b):
- from torch.utils._sympy.functions import BitwiseFn_bitwise_or
- return BitwiseFn_bitwise_or(a, b)
- def _bitwise_xor(a, b):
- from torch.utils._sympy.functions import BitwiseFn_bitwise_xor
- return BitwiseFn_bitwise_xor(a, b)
# Binary operations that also get a reflected (``__r<op>__``) user-facing
# method.  Maps this module's method name to the callable that builds the
# corresponding expression from two operands.
reflectable_magic_methods = {
    "add": operator.add,
    "sub": operator.sub,
    "mul": operator.mul,
    "mod": _sympy_mod,
    "pow_by_natural": _sympy_pow_by_natural,
    "float_pow": _sympy_float_pow,
    "and": _sympy_and,
    "bitwise_and": _bitwise_and,
    "or": _sympy_or,
    "bitwise_or": _bitwise_or,
    "bitwise_xor": _bitwise_xor,
    "float_truediv": _sympy_float_truediv,
    "int_truediv": _sympy_int_truediv,
    "int_floordiv": _sympy_floordiv,
    "lshift": _sympy_lshift,
    "rshift": _sympy_rshift,
}
- def _floor_ceil_helper(a, fn):
- import sympy
- if isinstance(a, sympy.Mul):
- aa = a.args
- if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
- coef = sympy.Integer(aa[0])
- if aa[0] == coef: # structural equality test
- return coef * aa[1]
- if (
- isinstance(a, sympy.Float)
- and a == sympy.Integer(a)
- or isinstance(a, sympy.Integer)
- ):
- return sympy.Integer(a)
- return fn(a)
- def _sympy_floor(a):
- from torch.utils._sympy.functions import FloorToInt
- return FloorToInt(a)
- # NB: this is Python trunc semantics which returns an int. Do NOT use this to
- # represent torch.trunc (which is float to float)
- def _sympy_trunc(a):
- from torch.utils._sympy.functions import TruncToInt
- return TruncToInt(a)
- def _sympy_ceil(a):
- from torch.utils._sympy.functions import CeilToInt
- return CeilToInt(a)
- def _sympy_eq(a, b):
- import sympy
- return sympy.Eq(a, b)
- def _sympy_ne(a, b):
- import sympy
- return sympy.Ne(a, b)
- def _sympy_gt(a, b):
- import sympy
- return sympy.Gt(a, b)
- def _sympy_lt(a, b):
- import sympy
- return sympy.Lt(a, b)
- def _sympy_le(a, b):
- import sympy
- return sympy.Le(a, b)
- def _sympy_ge(a, b):
- import sympy
- return sympy.Ge(a, b)
- def _sympy_min(a, b):
- from torch.utils._sympy.functions import Min
- return Min(a, b)
- def _sympy_max(a, b):
- from torch.utils._sympy.functions import Max
- return Max(a, b)
- def _sympy_ite(a, t, f):
- import sympy
- return sympy.Piecewise((t, a), (f, True))
# Handle on this module object, used below to set and look up the generated
# ``_sympy_<name>`` functions by attribute name.
current_module = sys.modules[__name__]
- def _get_sym_math_fn(name):
- def fn(a):
- import torch.utils._sympy.functions
- return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a)
- return fn
# For every math op name, generate a ``_sympy_<name>`` function, give it that
# name as __name__/__qualname__, and set it as an attribute of this module.
for name in math_op_names:
    priv_sympy_name = f"_sympy_{name}"
    fn = _get_sym_math_fn(name)
    fn.__qualname__ = fn.__name__ = priv_sympy_name
    setattr(current_module, priv_sympy_name, fn)

# Remove the loop temporaries from the module namespace.
del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
- def _sympy_abs(a):
- import sympy
- return sympy.Abs(a)
- def _sympy_round(number, ndigits=None):
- from torch.utils._sympy.functions import RoundDecimal, RoundToInt
- if ndigits is None:
- return RoundToInt(number)
- else:
- return RoundDecimal(number, ndigits)
- def _sympy_sym_float(a):
- from torch.utils._sympy.functions import ToFloat
- # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly
- # reports that it is an integer
- return ToFloat(a)
- def _sympy_is_integer(a):
- import sympy
- from torch.utils._sympy.functions import ToFloat
- return sympy.Eq(ToFloat(sympy.floor(a)), a)
# Every supported magic method: method name -> callable that builds the
# resulting expression from the operand expression(s).  Starts from the
# reflectable binary operations and adds the remaining unary, comparison and
# ternary ones.
magic_methods = {
    **reflectable_magic_methods,
    "sym_not": operator.invert,
    "pos": operator.pos,
    "eq": _sympy_eq,
    "ne": _sympy_ne,
    "gt": _sympy_gt,
    "lt": _sympy_lt,
    "le": _sympy_le,
    "ge": _sympy_ge,
    "floor": _sympy_floor,
    "trunc": _sympy_trunc,
    "sym_float": _sympy_sym_float,
    "ceil": _sympy_ceil,
    "neg": operator.neg,
    "sym_min": _sympy_min,
    "sym_max": _sympy_max,
    "sym_ite": _sympy_ite,
    "abs": _sympy_abs,
    "round": _sympy_round,
    "is_integer": _sympy_is_integer,
}


# Register a "sym_<name>" entry for every math op, backed by the
# "_sympy_<name>" attribute looked up on this module.
for name in math_op_names:
    sym_name = f"sym_{name}"
    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")

# Remove the temporaries (and the helpers no longer needed) from the module
# namespace.
del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
def sympy_is_contiguous(sizes, strides):
    """
    Return the symbolic contiguity condition for ``sizes``/``strides``.

    Delegates to ``sympy_is_contiguous_generic`` with the dimensions visited
    from the last one down to the first.
    """
    last_to_first = list(reversed(range(len(sizes))))
    return sympy_is_contiguous_generic(sizes, strides, last_to_first)
- def sympy_is_contiguous_generic(sizes, strides, dim_order):
- import sympy
- dim = len(sizes)
- if len(dim_order) != dim:
- return sympy.false
- is_contiguous = sympy.true
- z = sympy.S.One
- # Contiguous if the strides make sense (or the dim is size 1)
- for d in dim_order:
- is_contiguous &= sympy.Eq(sizes[d], sympy.S.One) | sympy.Eq(strides[d], z)
- z *= sizes[d]
- # OR if any size is zero
- for d in range(dim):
- is_contiguous |= sympy.Eq(sizes[d], sympy.S.Zero)
- return is_contiguous
# NB: There is a TODO in C++ to allow omitting the batch dim.  If that
# happens you will need to refactor this
def sympy_is_channels_last_contiguous_2d(sizes, strides):
    """Contiguity condition for the dimension visiting order [1, 3, 2, 0]."""
    dim_order = [1, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, dim_order)
def sympy_is_channels_last_contiguous_3d(sizes, strides):
    """Contiguity condition for the dimension visiting order [1, 4, 3, 2, 0]."""
    dim_order = [1, 4, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, dim_order)
- def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
- import sympy
- from torch.utils._sympy.functions import Max
- dim = len(sizes)
- if dim != len(dim_order):
- return sympy.false
- m = sympy.S.Zero
- r = sympy.true
- # special case for trivial C dimension. default to NCHW
- r &= sympy.Ne(strides[1], 0)
- for d in dim_order:
- r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
- # Fallback to NCHW as default layout for ambiguous cases
- # This is the flaw of implicit memory_format from strides.
- # N111 tensor with identical strides for size 1 dimension;
- # Two cases could lead us here:
- # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
- # b. N11W contiguous Tensor sliced on the W-dimension.
- # ([N,1,1,1]@[W,W,W,W])
- if d == 0:
- r &= sympy.Ne(m, strides[1])
- # This is necessary to:
- # 1. distinguish the memory_format of N1H1;
- # [H, 1, 1, 1] channels_last stride
- # [H, H, 1, 1] contiguous stride
- # 2. permutation of 1C1W:
- # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
- # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
- # channels_last
- m = strides[d] * Max(sizes[d], 1)
- return r
def sympy_is_channels_last_strides_2d(sizes, strides):
    """Channels-last strides condition for the visiting order [1, 3, 2, 0]."""
    dim_order = [1, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, dim_order)
def sympy_is_channels_last_strides_3d(sizes, strides):
    """Channels-last strides condition for the visiting order [1, 4, 3, 2, 0]."""
    dim_order = [1, 4, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, dim_order)
- def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
- from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
- return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
# Layout predicates over (sizes, strides): method name -> function building
# the symbolic result from lists of size and stride expressions.
sizes_strides_methods = {
    # TODO: These could also be done with indicators, maybe it is better
    # for reasoning to do it that way
    "is_contiguous": sympy_is_contiguous,
    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
}
def to_node(self, num):
    """
    Convert ``num`` into a node, using ``self`` to wrap plain Python numbers.

    Symbolic values yield their ``.node``; exact ``bool``/``int``/``float``
    instances are wrapped through ``self``.  Anything else yields
    ``NotImplemented``.
    """
    if isinstance(num, SymTypes):
        return num.node

    # Exact type checks: bool must not be treated as int, and subclasses are
    # deliberately not accepted.
    exact_type = type(num)
    if exact_type is bool:
        return self.wrap_bool(num)
    if exact_type is int:
        return self.wrap_int(num)
    if exact_type is float:
        return self.wrap_float(num)

    # NotImplemented is important so that Python tries the
    # other magic method
    return NotImplemented
def wrap_node(x):
    """
    Wrap node ``x`` into the matching user-facing value.

    A ``SymNode`` carrying a constant is unwrapped to that constant; otherwise
    the node is wrapped as ``SymInt``, ``SymFloat`` or ``SymBool`` according
    to its type.  Raises AssertionError when the node is none of those.
    """
    # TODO: let C++ also take advantage of this
    if isinstance(x, SymNode) and x.constant is not None:
        return x.constant

    if x.is_int():
        return SymInt(x)
    if x.is_float():
        return SymFloat(x)
    if x.is_bool():
        return SymBool(x)
    raise AssertionError(f"unrecognized return type {x}")
def method_to_operator(method):
    """Look up the operator callable registered for ``method``.

    Raises KeyError when ``method`` has no entry in ``METHOD_TO_OPERATOR``.
    """
    op = METHOD_TO_OPERATOR[method]
    return op
def _make_node_magic(method, func):
    """
    Install the SymNode-level implementation of ``method`` on ``SymNode``.

    ``func`` is the callable that builds the result expression from the
    operand expression(s); it is wrapped in an LRU cache of size 256.
    Depending on ``method``, one of four implementations is attached to
    ``SymNode`` under the attribute ``_<method_attr>``:
      * the unary implementation, for methods in ``unary_methods``;
      * the ternary implementation, for ``sym_ite``;
      * the ``round`` implementation, which takes an optional ``ndigits``;
      * the binary implementation, for everything else.
    """
    func = lru_cache(256)(func)

    # Some methods are stored on SymNode with a trailing underscore.
    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"{method}_"
    else:
        method_attr = method

    # Returns the set of source files of a fixed list of torch modules, plus
    # torch._dynamo.guards.uninteresting_files() and "<string>".
    def uninteresting_files() -> set[str]:
        import torch

        mods = [
            torch._dynamo.eval_frame,
            torch._dynamo.utils,
            torch.fx.experimental.sym_node,
            torch,
        ]
        import torch._dynamo.guards

        return (
            {inspect.getfile(m) for m in mods}
            | torch._dynamo.guards.uninteresting_files()
            | {"<string>"}
        )

    # Decorator: runs the wrapped implementation and, when
    # GET_DTRACE_STRUCTURED is set, emits an "expression_created" structured
    # trace record describing the result and its arguments.
    def capture_provenance(fn):
        @functools.wraps(fn)
        def wrapper(self, other=None):
            # ``other is None`` selects the unary calling convention.
            if other is None:
                result = fn(self)
            else:
                result = fn(self, other)
            if torch._logging._internal.GET_DTRACE_STRUCTURED:
                if other is not None:
                    arguments = [self, other]
                else:
                    arguments = [self]

                def get_id(sym_node) -> Optional[int]:
                    # We don't want to return an ID if the input is a constant
                    import sympy

                    if sym_node.constant is not None:
                        return None
                    elif id(sym_node) == id(result):
                        return None
                    elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)):
                        return None
                    elif sym_node.expr in (sympy.true, sympy.false):
                        return None
                    return id(sym_node)

                dtrace_structured(
                    "expression_created",
                    metadata_fn=lambda: {
                        "method": method,
                        "result": str(result),
                        "result_id": id(result),
                        "arguments": [str(a) for a in arguments],
                        "argument_ids": [
                            get_id(i) for i in arguments if get_id(i) is not None
                        ],
                        "user_stack": structured.get_user_stack(3),
                        "stack": structured.get_framework_stack(3),
                    },
                )

            return result

        return wrapper

    # Binary implementation: both operands are expected to be SymNodes.
    @capture_provenance
    def binary_magic_impl(self, other):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        op = method_to_operator(method)

        # The hint is only computed when both operands have one.
        out_hint: object = _NO_HINT
        if self.hint is not None and other.hint is not None:
            out_hint = op(self.hint, other.hint)

        # Under a proxy mode the operation is dispatched through
        # handle_sym_dispatch and its result converted back to a node.
        if get_proxy_mode():
            return to_node(
                self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
            )
        if not isinstance(other, SymNode):
            raise AssertionError(f"Expected SymNode, got {type(other)}")
        optimized_summation = False
        try:
            if method == "mod":
                from torch.utils._sympy.functions import Mod, PythonMod

                # Special handling for mod that requires access to the value
                # ranges
                shape_env = self.shape_env
                if (
                    self.expr.is_nonnegative
                    or shape_env.bound_sympy(self.expr).lower >= 0
                ) and (
                    other.expr.is_nonnegative
                    or shape_env.bound_sympy(other.expr).lower >= 0
                ):
                    out = Mod(self.expr, other.expr)
                else:
                    out = PythonMod(self.expr, other.expr)
            elif method == "add":
                # see Note [optimized_summation]
                (optimized_summation, out) = _optimized_add(
                    self.expr,
                    other.expr,
                    self._optimized_summation,
                    other._optimized_summation,
                )
            elif method in ("eq", "ne", "ge", "gt", "le", "lt"):
                import sympy

                from torch.utils._sympy.symbol import symbol_is_type, SymT

                # Optimization: when one side is a single unbacked symbol
                # and other is constant, use evaluate=False to skip expensive
                # relational evaluation. We only do this for unbacked symbols
                # because they have no assumptions (like positive=True) that
                # sympy would use during evaluation.
                lhs_is_unbacked = self.expr.is_symbol and symbol_is_type(
                    self.expr, SymT.UNBACKED_INT
                )
                rhs_is_unbacked = other.expr.is_symbol and symbol_is_type(
                    other.expr, SymT.UNBACKED_INT
                )
                if (lhs_is_unbacked and other.expr.is_number) or (
                    rhs_is_unbacked and self.expr.is_number
                ):
                    rel_class = {
                        "eq": sympy.Eq,
                        "ne": sympy.Ne,
                        "ge": sympy.Ge,
                        "gt": sympy.Gt,
                        "le": sympy.Le,
                        "lt": sympy.Lt,
                    }[method]
                    out = rel_class(self.expr, other.expr, evaluate=False)
                else:
                    out = func(self.expr, other.expr)
            else:
                # TODO: consider constant prop here
                out = func(self.expr, other.expr)
        except Exception:
            # Log the failing operation, then let the exception propagate.
            log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
            raise
        sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
        pytype: type
        # This is not strictly correct. In Python, a**b may return complex when
        # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This
        # returns a float while both arguments are ints: 2**(-1). Also, max and
        # min do not type promote. To avoid having data-dependent control flow
        # here, we just set the type to float if one of the args is a float. In
        # case of a type mismatch, we assume that it will be detected during
        # evaluation.
        if method in always_float_magic_methods:
            pytype = float
        elif method in always_bool_magic_methods:
            pytype = bool
        elif self.pytype is float or other.pytype is float:
            pytype = float
        else:
            pytype = self.pytype

        # Coerce a concrete (non-symbolic) hint to the result's Python type.
        if (
            pytype is not None
            and out_hint is not _NO_HINT
            and out_hint is not None
            and not isinstance(out_hint, SymTypes)
        ):
            out_hint = pytype(out_hint)  # type: ignore[arg-type]

        # Create a FX node that corresponds to the operation being applied to
        # this node.
        fx_node, _ = self.shape_env._create_fx_call_function(
            op, (self.fx_node, other.fx_node)
        )

        result = SymNode(
            out,
            self.shape_env,
            pytype,
            out_hint,  # type: ignore[arg-type]
            fx_node=fx_node,
            optimized_summation=optimized_summation,  # see Note [optimized_summation]
        )
        return result

    # Unary implementation.
    @capture_provenance
    def unary_magic_impl(self):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        op = method_to_operator(method)
        if get_proxy_mode():
            return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
        # TODO: consider constant prop here
        expr = self.expr
        # NOTE(review): magic_methods registers the ceiling operation under
        # the key "ceil", so the "ceiling" comparison below looks like it can
        # never match -- confirm whether that is intended before changing it.
        if method == "floor" or method == "ceiling":
            expr = self.shape_env._simplify_floor_div(expr)

        try:
            out = func(expr)
        except Exception:
            log.warning("failed to eval %s(%s)", method, expr)
            raise
        sym_node_log.debug("%s %s -> %s", func, expr, out)
        out_hint: object = _NO_HINT
        if self.hint is not None:
            out_hint = op(self.hint)
        pytype: type
        if method in always_int_magic_methods:
            pytype = int
        elif method in always_bool_magic_methods:
            pytype = bool
        elif method in always_float_magic_methods:
            pytype = float
        else:
            pytype = self.pytype

        fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
        return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)  # type: ignore[arg-type]

    if method in unary_methods:
        setattr(SymNode, f"_{method_attr}", unary_magic_impl)
    elif method == "sym_ite":

        # Ternary implementation: the result takes then_node's pytype.
        def sym_ite_impl(pred_node, then_node, else_node):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )

            # A falsy predicate hint (including None) selects else_node's hint.
            out_hint = then_node.hint if pred_node.hint else else_node.hint
            if get_proxy_mode():
                return to_node(
                    pred_node,
                    handle_sym_dispatch(
                        sym_ite,
                        (
                            wrap_node(pred_node),
                            wrap_node(then_node),
                            wrap_node(else_node),
                        ),
                        {},
                    ),
                )

            try:
                out = func(pred_node.expr, then_node.expr, else_node.expr)
            except Exception:
                log.warning(
                    "failed to eval %s(%s, %s, %s)",
                    method,
                    pred_node.expr,
                    then_node.expr,
                    else_node.expr,
                )
                raise

            fx_node, _ = pred_node.shape_env._create_fx_call_function(
                sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
            )
            return SymNode(
                out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
            )

        setattr(SymNode, f"_{method_attr}", sym_ite_impl)
    elif method == "round":

        # round implementation: the result is int without ndigits, otherwise
        # it keeps the operand's pytype.
        def round_impl(self, ndigits=None):
            from torch.fx.experimental.proxy_tensor import (
                get_proxy_mode,
                handle_sym_dispatch,
            )

            op = builtins.round
            if get_proxy_mode():
                return to_node(
                    self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
                )

            expr = self.expr
            try:
                out = func(expr, ndigits)
            except Exception:
                log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
                raise

            if ndigits is None:
                pytype = int
            else:
                pytype = self.pytype

            out_hint = None
            if self.hint is not None:
                out_hint = op(self.hint, ndigits)

            # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the
            # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here
            # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The
            # hack down below works, because all round function down the line all take ndigits=None as default in their
            # signature.
            # TODO: Remove the args construction below if a different sentinel is used by FX.
            # ezyang(May 2024): LOL
            args = [self.fx_node]
            if ndigits is not None:
                args.append(ndigits)
            fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
            return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)

        setattr(SymNode, f"_{method_attr}", round_impl)
    else:
        setattr(SymNode, f"_{method_attr}", binary_magic_impl)
def _make_node_sizes_strides(method, func):
    """
    Install the (sizes, strides) predicate ``method`` at two levels.

    ``SymNode._<method>`` is set to an implementation operating on lists of
    nodes.  In addition, unless this module already has an attribute named
    ``method``, a user-level function taking plain/symbolic sizes and strides
    is set on this module under that name.  ``func`` builds the symbolic
    result from lists of size and stride expressions.
    """
    # NB: don't LRU cache, lots of arguments

    def sizes_strides_impl(self, sizes, strides):
        from torch.fx.experimental.proxy_tensor import (
            get_proxy_mode,
            handle_sym_dispatch,
        )

        # The user-level function of the same name on this module.
        op = getattr(sys.modules[__name__], method)
        if get_proxy_mode():
            return to_node(
                self,
                handle_sym_dispatch(
                    op,
                    ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
                    {},
                ),
            )
        size_exprs = [s.expr for s in sizes]
        stride_exprs = [s.expr for s in strides]
        try:
            out = func(size_exprs, stride_exprs)
        except Exception:
            log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
            raise
        # bool is never expandable

        # The hint is computed only when every size and every stride has a
        # hint: each loop's ``else`` clause runs only if that loop was not
        # left through ``break``.
        size_hints = []
        out_hint = None
        for s in sizes:
            if s.hint is None:
                break
            size_hints.append(s.hint)
        else:
            stride_hints = []
            for s in strides:
                if s.hint is None:
                    break
                stride_hints.append(s.hint)
            else:
                out_hint = op(size_hints, stride_hints)

        # NB: This is the indicator function, not the actual bool!
        pytype: type
        if method.endswith("_indicator"):
            pytype = int
        else:
            pytype = bool
        return SymNode(out, self.shape_env, pytype, out_hint)

    setattr(SymNode, f"_{method}", sizes_strides_impl)

    # TODO: This is technically hotpath, but in the ideal end state
    # guards on this will resolve at a higher level so you never
    # spend time in this code
    def sizes_strides_user(sizes, strides):
        import sympy

        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        # If any argument is a SymInt, delegate to the node-level method of
        # the first one found, converting every argument to a node.
        for a in itertools.chain(sizes, strides):
            if isinstance(a, SymInt):
                return wrap_node(
                    getattr(a.node, method)(
                        [to_node(a.node, b) for b in sizes],
                        [to_node(a.node, b) for b in strides],
                    )
                )
        # No SymInt among the arguments: evaluate directly.
        if method == "is_non_overlapping_and_dense_indicator":
            return eval_is_non_overlapping_and_dense(sizes, strides)
        else:
            # TODO: this is an awful implementation
            return bool(
                func(
                    [sympy.sympify(a) for a in sizes],
                    [sympy.sympify(a) for a in strides],
                )
            )

    # Skip for is_non_overlapping_and_dense_indicator
    if not hasattr(sys.modules[__name__], method):
        setattr(sys.modules[__name__], method, sizes_strides_user)
# Install the SymNode-level implementation of every magic method.
for method, func in magic_methods.items():
    _make_node_magic(method, func)

# Install the SymNode-level and module-level (sizes, strides) predicates.
for method, func in sizes_strides_methods.items():
    _make_node_sizes_strides(method, func)
def _make_user_magic(method, user_type):
    """
    Install the user-facing implementation of ``method`` on ``user_type``.

    The installed method wraps/promotes its operands, short-circuits when an
    operand is constant, and otherwise forwards to the attribute named
    ``method_attr`` on the operand's node.  When ``user_type`` is ``SymInt``,
    a corresponding wrapper around the ``int`` method of the same name is
    also installed on ``DynamicInt``.
    """
    # User magic takes care of wrapping the other operand into a node,
    # so that our internal logic can assume everything is nodes

    if method in magic_methods_on_operator_with_trailing_underscore:
        method_attr = f"sym_{method}"
    else:
        method_attr = method

    # Extract the Python value of a constant operand.
    # NOTE(review): there is no SymFloat branch here although is_constant()
    # below accepts SymFloat, so a constant SymFloat would reach the
    # AssertionError -- confirm whether that case can occur.
    def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
        if isinstance(x, (int, float, bool)):
            return x
        if isinstance(x, SymInt):
            return x.node.guard_int("", 0)
        if isinstance(x, SymBool):
            return x.node.guard_bool("", 0)
        raise AssertionError("expect to be called with constant SymBools")

    # True for plain Python numbers and for symbolic values whose node
    # reports itself constant.
    def is_constant(x):
        if isinstance(x, (int, float, bool)):
            return True
        if isinstance(x, (SymInt, SymFloat, SymBool)):
            return x.node.is_constant()
        return False

    # Promotion rules for binary operations.  NB: we preserve PYTHON semantics
    #   - if args are same type, do nothing
    #   - if one arg is float, promote other arg to float
    #       - nb: this applies to floordiv, even though output is integral
    #       (it's still float)
    #   - pow is funny business
    #       - if both ints
    #       - trigger a guard on exponent >= 0
    #           - if non-negative, output is int
    #           - otherwise, output is float
    #   - otherwise, promote other arg to float
    #       - nb: complex is impossible to handle correctly lol, with
    #       negative base and integral float need to diverge semantics and
    #       just always return complex.  Neener neener pretend this problem
    #       doesn't exist
    #   - equality is pain: Python does the fancy thing where it unpacks the
    #     mantissa from the float and then compares that against the int.
    #     Which means it is able to tell that
    #     9007199254740993 != 9007199254740992.  (rather than if the LHS was
    #     promoted to float, in which case it would have truncated to the RHS
    #     and subsequently been equal).  We'll model this exactly by having
    #     special mixed type equality operations.  Unfortunately, we need to
    #     do this for all comparison operations (maybe I'll only implement
    #     compare)
    #   - sym_ite mumble mumble really shouldn't allow mixed but whatever

    if method in bool_becomes_int_magic_methods:

        def promote(x):
            """Implements True+True=2, which works in python but not sympy"""
            if isinstance(x, SymBool):
                return SymInt(x.node.wrap_int(int(x)))
            return x

    else:

        # No bool-to-int promotion for this method.
        def promote(x):
            return x

    # Float promotion: for the listed methods, if either operand is a float
    # (plain or symbolic), the other one is converted with torch.sym_float.
    def promote2(self, other):
        # TODO: Remove eq and other relations from this list.
        # CPython has fancy implementations for these to get as much precision
        # as possible instead of just promoting to float64 and praying, so we
        # need to handle them specially too.
        # Also, note that int_truediv doesn't go through this path: both
        # arguments are "int" so there isn't any promotion
        if method not in [
            "add",
            "sub",
            "mul",
            "mod",
            "float_pow",
            "float_truediv",
            "int_floordiv",
            "sym_min",
            "sym_max",
            # TODO: remove these
            "eq",
            "ne",
            "gt",
            "lt",
            "le",
            "ge",
        ]:
            return self, other
        f_self = isinstance(self, (float, torch.SymFloat))
        f_other = isinstance(other, (float, torch.SymFloat))
        if f_self or f_other:
            if not f_self:
                self = torch.sym_float(self)
            if not f_other:
                other = torch.sym_float(other)
        return self, other

    # Before and after performing the operation, check if any operands are constant.
    # If so, extract out the constant values first. If `self` itself is a
    # constant, then "redispatch" by calling back into the operator. Sometimes
    # this means that operations involving SymBool return plain bools.
    # Alternatively, we could also rewrap into constant Symbool (i.e. by
    # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that
    # today for no particular reason.
    def unary_magic_impl(self):
        self = promote(self)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self))
        return wrap_node(getattr(self.node, method_attr)())

    def binary_magic_impl(self, other):
        # Unsupported operand types let Python try the other operand's method.
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        sym_node_log.debug("MAGIC %s %s %s", method, self, other)
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            return (method_to_operator(method))(get_constant(self), other)
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(self.node, method_attr)(other_node))
        return get_constant(ret) if is_constant(ret) else ret

    # Reflected variant: ``other`` is the left operand, ``self`` the right.
    def rbinary_magic_impl(self, other):
        if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
            return NotImplemented
        self = promote(self)
        other = promote(other)
        self, other = promote2(self, other)
        if is_constant(self):
            return (method_to_operator(method))(other, get_constant(self))
        if is_constant(other):
            other = get_constant(other)
        other_node = to_node(self.node, other)
        if other_node is NotImplemented:
            return NotImplemented
        ret = wrap_node(getattr(other_node, method_attr)(self.node))
        return get_constant(ret) if is_constant(ret) else ret

    def setattrs(user_type, attr, symnode_impl):
        """
        Registers the SymNode magic method on SymInt/Float/Bool,
        and optionally registers a corresponding wrapped method on DynamicInt.
        """

        # SymInt/Float/Bool
        setattr(user_type, attr, symnode_impl)

        # DynamicInt impl
        def dynamic_int_impl(*args):
            # Unwrap DynamicInt arguments, call the ``int`` method of the same
            # name, and re-wrap int (but not bool) results as DynamicInt.
            args = [x.real if isinstance(x, DynamicInt) else x for x in args]
            out = getattr(int, attr)(*args)
            if isinstance(out, int) and not isinstance(out, bool):
                return DynamicInt(out)
            return out

        if user_type is SymInt:
            setattr(DynamicInt, attr, dynamic_int_impl)

    if method in unary_magic_methods:
        setattrs(user_type, f"__{method}__", unary_magic_impl)
    elif method in unary_nonmagic_methods:
        # Replace the existing non-dunder method, keeping its metadata.
        orig = getattr(user_type, method)
        setattrs(user_type, method, update_wrapper(unary_magic_impl, orig))
    elif method == "sym_ite":

        def sym_ite_magic_impl(pred, then_val, else_val):
            pred_node = pred.node
            then_node = to_node(pred_node, then_val)
            else_node = to_node(pred_node, else_val)
            if then_node is NotImplemented or else_node is NotImplemented:
                return NotImplemented
            if not (
                isinstance(then_node, SymNode)
                and isinstance(else_node, SymNode)
                and then_node.pytype == else_node.pytype
            ):
                raise AssertionError(
                    "then_node and else_node must be SymNodes with same pytype"
                )
            ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
            # NOTE(review): this reads ``ret.node`` directly, whereas
            # wrap_node can return a plain constant -- confirm that cannot
            # happen on this path.
            return get_constant(ret) if ret.node.is_constant() else ret

        setattrs(user_type, f"__{method}__", sym_ite_magic_impl)
    elif method == "round":

        def round_magic_impl(self, ndigits=None):
            if is_constant(self):
                return builtins.round(get_constant(self), ndigits)

            return wrap_node(getattr(self.node, method)(ndigits))

        setattrs(user_type, f"__{method}__", round_magic_impl)
    else:
        # Binary method; bitwise ops are installed under the dunder name
        # given by the ``bitwise_ops`` mapping.
        method_name = method
        if method in bitwise_ops:
            method_name = bitwise_ops[method]
        setattrs(user_type, f"__{method_name}__", binary_magic_impl)
        if method in reflectable_magic_methods:
            setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl)
# Install the user-facing magic methods on the symbolic types:
#   - bool-only methods go on SymBool only;
#   - float-only methods go on SymFloat only;
#   - everything else goes on SymInt, also on SymBool when listed in
#     also_bool_magic_methods / bool_becomes_int_magic_methods, and on
#     SymFloat unless it is a bitwise op.
for method in magic_methods:  # type: ignore[assignment]
    if method in only_bool_magic_methods:
        _make_user_magic(method, SymBool)
        continue
    if method in only_float_magic_methods:
        _make_user_magic(method, SymFloat)
        continue
    if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
        _make_user_magic(method, SymBool)
    _make_user_magic(method, SymInt)
    if method not in bitwise_ops:
        _make_user_magic(method, SymFloat)

# Remove the loop variables from the module namespace.
del method
del func
|