| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509 |
- # mypy: allow-untyped-defs
- import functools
- import math
- import operator
- import sys
- from collections.abc import Callable
- from typing import SupportsFloat, TYPE_CHECKING, TypeVar
- from typing_extensions import TypeVarTuple, Unpack
- import sympy
- from sympy import S
- from sympy.core import sympify
- from sympy.core.expr import Expr
- from sympy.core.function import Application
- from sympy.core.logic import _torf, fuzzy_and, fuzzy_or
- from sympy.core.numbers import equal_valued
- from sympy.core.operations import LatticeOp, ShortCircuit
- from sympy.core.sorting import ordered
- from sympy.core.traversal import walk
- from sympy.printing.precedence import PRECEDENCE
- from sympy.utilities.iterables import sift
- from torch.torch_version import TorchVersion
- from .numbers import int_oo, is_infinite
- if TYPE_CHECKING:
- from collections.abc import Iterable
- _T = TypeVar("_T", bound=SupportsFloat)
- _Ts = TypeVarTuple("_Ts")
- # Portions of this file are adapted from the Sympy codebase, which was
- # licensed as follows:
- #
- # Copyright (c) 2006-2023 SymPy Development Team
- #
- # All rights reserved.
- #
- # Redistribution and use in source and binary forms, with or without
- # modification, are permitted provided that the following conditions are met:
- #
- # a. Redistributions of source code must retain the above copyright notice,
- # this list of conditions and the following disclaimer.
- # b. Redistributions in binary form must reproduce the above copyright
- # notice, this list of conditions and the following disclaimer in the
- # documentation and/or other materials provided with the distribution.
- # c. Neither the name of SymPy nor the names of its contributors
- # may be used to endorse or promote products derived from this software
- # without specific prior written permission.
- #
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
- # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
- # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
- # ARE DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE FOR
- # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
- # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
- # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
- # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
- # LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
- # OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
- # DAMAGE.
- __all__ = [
- "FloorDiv",
- "ModularIndexing",
- "Where",
- "PythonMod",
- "Mod",
- "CleanDiv",
- "CeilToInt",
- "FloorToInt",
- "CeilDiv",
- "IntTrueDiv",
- "FloatTrueDiv",
- "LShift",
- "RShift",
- "IsNonOverlappingAndDenseIndicator",
- "TruncToFloat",
- "TruncToInt",
- "RoundToInt",
- "RoundDecimal",
- "ToFloat",
- "FloatPow",
- "PowByNatural",
- "Identity",
- ]
- def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
- # No need to check that two args are not the same, since expr is pr-optimized but we do it anyway.
- return (
- isinstance(expr, sympy.Expr)
- and expr.is_Add
- and len(expr._args) == 2
- and expr._args[0].is_symbol
- and expr._args[1].is_symbol
- and expr._args[0] is not expr._args[1]
- )
- def _keep_float(
- f: Callable[[Unpack[_Ts]], _T],
- ) -> Callable[[Unpack[_Ts]], _T | sympy.Float]:
- @functools.wraps(f)
- def inner(*args: Unpack[_Ts]) -> _T | sympy.Float:
- r: _T | sympy.Float = f(*args)
- if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
- r, sympy.Float
- ):
- r = sympy.Float(float(r))
- return r
- # pyrefly: ignore [bad-return]
- return inner
- def fuzzy_eq(x: bool | None, y: bool | None) -> bool | None:
- if None in (x, y):
- return None
- return x == y
- def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
- """
- Fast path for sympy.gcd, using a simple factoring strategy.
- We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0,
- where n is the greatest common integer factor and e is the largest
- syntactic common factor (i.e., common sub-expression) in p and q.
- Then the gcd returned is n*e, cancelling which we would be left with
- p1 + p2 and q0.
- Note that further factoring of p1 + p2 and q0 might be possible with
- sympy.factor (which uses domain-specific theories). E.g., we are unable
- to find that x*y + x + y + 1 is divisible by x + 1. More generally,
- when q is of the form q1 + q2 (instead of being already factored) it
- might be necessary to fall back on sympy.gcd.
- """
- def integer_coefficient(x: sympy.Basic) -> int:
- integer_coefficients: list[int] = [
- abs(int(arg))
- for arg in sympy.Mul.make_args(x)
- if isinstance(arg, (int, sympy.Integer))
- ]
- return math.prod(integer_coefficients)
- def integer_factor(expr: sympy.Basic) -> int:
- integer_factors: Iterable[int] = map(
- integer_coefficient, sympy.Add.make_args(expr)
- )
- return functools.reduce(math.gcd, integer_factors)
- gcd: int = math.gcd(integer_factor(p), integer_factor(q))
- p, q = p / gcd, q / gcd # type: ignore[operator, assignment] # remove in py3.12
- base_splits: list[tuple[sympy.Basic, ...]] = list(
- map(sympy.Mul.make_args, sympy.Add.make_args(p))
- )
- divisor_split: tuple[sympy.Basic, ...] = sympy.Mul.make_args(q)
- for x in divisor_split:
- if all(x in base_split for base_split in base_splits):
- gcd = gcd * x # type: ignore[operator] # remove in py3.12
- return gcd # type: ignore[return-value] # remove in py3.12
- # It would be nice to have assertions on whether or not inputs is_integer
- # However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy
- # sometimes inconsistently reports floats an integers.
- #
- # What we can assume from sympy is that if something is an int, it
- # definitely is is_integer, but if it is a float it may or may not
- # be is_integer. So we are unable to do strong asserts that things
- # are NOT integers.
- # TODO: In Triton, // rounds to zero, but in Python, it is floor division.
- # When we can prove both arguments are non-negative, we should just have a
- # GenericFloorDiv (name pending) which can codegen efficiently in Python/C,
- # and then PythonFloorDiv and CIntDiv which have the appropriate rounding
- # semantics.
- #
- # Right now, FloorDiv de facto changes behavior if arguments are negative or
- # not, this can potentially cause correctness issues.
- class FloorDiv(sympy.Function):
- """
- We maintain this so that:
- 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
- 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b)
- NB: This is Python-style floor division, round to -Inf
- """
- nargs: tuple[int, ...] = (2,)
- precedence: int = 35 # lower precedence than add
- is_integer: bool = True
- @property
- def base(self) -> sympy.Basic:
- # pyrefly: ignore [missing-attribute]
- return self.args[0]
- @property
- def divisor(self) -> sympy.Basic:
- # pyrefly: ignore [missing-attribute]
- return self.args[1]
- def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
- base = printer.parenthesize(self.base, PRECEDENCE["Atom"] - 0.5)
- divisor = printer.parenthesize(self.divisor, PRECEDENCE["Atom"] - 0.5)
- return f"({base}//{divisor})"
- # Automatic evaluation.
- # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval
- @classmethod
- def eval(cls, base: sympy.Integer, divisor: sympy.Integer) -> sympy.Basic | None:
- # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
- # Assert triggered by inequality solver
- # assert base.is_integer, base
- # assert divisor.is_integer, divisor
- # We don't provide the same error message as in Python because SymPy
- # makes it difficult to check the types.
- if divisor.is_zero:
- raise ZeroDivisionError("division by zero")
- if is_infinite(base) and is_infinite(divisor):
- return sympy.nan
- if base is sympy.nan or divisor is sympy.nan:
- return sympy.nan
- if base.is_zero:
- return sympy.S.Zero
- if base.is_integer and equal_valued(divisor, 1):
- return base
- if base.is_integer and equal_valued(divisor, -1):
- return sympy.Mul(base, -1)
- if base == divisor:
- return sympy.S.One
- if (
- isinstance(base, sympy.Number)
- and isinstance(divisor, sympy.Number)
- and (is_infinite(base) or is_infinite(divisor))
- ):
- r = float(base) / float(divisor)
- if r == math.inf:
- return int_oo
- elif r == -math.inf:
- return -int_oo
- elif math.isnan(r):
- return sympy.nan
- else:
- return sympy.Integer(math.floor(r))
- if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
- return sympy.Integer(int(base) // int(divisor))
- if isinstance(base, FloorDiv):
- return FloorDiv(base.args[0], base.args[1] * divisor)
- # Expands (x + y) // b into x // b + y // b.
- # This only works if floor is an identity, i.e. x / b is an integer.
- if isinstance(divisor, sympy.Integer):
- quotients = 0
- terms = []
- for term in sympy.Add.make_args(base):
- quotient = term / divisor
- # This is a sympy bug fixed in https://github.com/sympy/sympy/pull/28442
- # sympy can generate a quotient with (1/22)*.... such that quotient.is_integer is True
- # FloorDiv should not allow that as output. see
- quotient_is_integer = None
- if isinstance(quotient, sympy.Mul) and TorchVersion(
- sympy.__version__
- ) < TorchVersion("1.15.0"):
- rationals = quotient.atoms(sympy.Rational)
- all_rationals_ints = all(r.q == 1 for r in rationals)
- quotient_is_integer = quotient.is_integer and all_rationals_ints
- else:
- quotient_is_integer = quotient.is_integer
- if quotient_is_integer:
- terms.append(term)
- quotients += quotient
- if len(terms) != 0:
- # Passing evaluate = False since expression will be optimized during the subtraction post its construction.
- return (
- FloorDiv(base - sympy.Add(*terms, evaluate=False), divisor)
- + quotients
- )
- try:
- gcd = simple_floordiv_gcd(base, divisor)
- if equal_valued(gcd, 1) and isinstance(divisor, sympy.Add):
- gcd = sympy.gcd(base, divisor)
- if not equal_valued(gcd, 1):
- return FloorDiv(
- sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
- )
- except sympy.PolynomialError:
- pass # https://github.com/pytorch/pytorch/issues/108276
- return None
- def _eval_is_nonnegative(self) -> bool | None:
- # pyrefly: ignore [missing-attribute]
- p, q = self.args[:2]
- if all([p.is_integer, q.is_integer, p.is_nonnegative, q.is_nonnegative]):
- return True
- return None
- class ModularIndexing(sympy.Function):
- """
- ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus
- """
- nargs: tuple[int, ...] = (3,)
- is_integer: bool = True
- precedence: int = 35 # lower precedence than add
- @classmethod
- def eval(
- cls, base: sympy.Integer, divisor: sympy.Integer, modulus: sympy.Integer
- ) -> sympy.Basic | None:
- if base == 0 or modulus == 1:
- return sympy.S.Zero
- if (
- isinstance(base, sympy.Integer)
- and isinstance(divisor, sympy.Integer)
- and isinstance(modulus, sympy.Integer)
- ):
- return (base // divisor) % modulus
- try:
- if divisor != 1:
- gcd = sympy.gcd(base, divisor)
- if gcd != 1:
- return ModularIndexing(
- sympy.simplify(base / gcd),
- sympy.simplify(divisor / gcd),
- modulus,
- )
- except sympy.PolynomialError:
- pass # https://github.com/pytorch/pytorch/issues/108276
- if isinstance(base, sympy.Add):
- new_terms: list[sympy.Integer] = []
- all_positive: bool = True
- for term in base.args:
- if sympy.gcd(term, modulus * divisor) != modulus * divisor:
- if (isinstance(term, sympy.Integer) and term < 0) or (
- isinstance(term, sympy.Mul)
- and isinstance(term.args[0], sympy.Integer)
- and term.args[0] < 0
- ):
- # workaround for https://github.com/triton-lang/triton/issues/619,
- # if there are negative terms, // produces wrong result
- # TODO if https://github.com/triton-lang/triton/issues/619 is fixed
- # this optimization would become valid
- all_positive = False
- break
- else:
- new_terms.append(term)
- if len(new_terms) != len(base.args) and all_positive:
- return ModularIndexing(sum(new_terms), divisor, modulus)
- if isinstance(base, FloorDiv):
- return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
- return None
- def _eval_is_nonnegative(self) -> bool | None:
- # pyrefly: ignore [missing-attribute]
- p, q = self.args[:2]
- return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]
- class Where(sympy.Function):
- """
- Good ol' ternary operator
- """
- nargs: tuple[int, ...] = (3,)
- precedence: int = 35 # lower precedence than add
- def _eval_is_integer(self) -> bool | None:
- return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
- def _eval_is_nonnegative(self) -> bool | None:
- return (
- True
- if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined]
- else None
- )
- def _eval_is_positive(self) -> bool | None:
- return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined]
- @classmethod
- def eval(cls, c: sympy.Basic, p: sympy.Basic, q: sympy.Basic) -> sympy.Basic | None:
- if c == sympy.true:
- return p
- elif c == sympy.false:
- return q
- return None
- # Python-style modulus: take sign from RHS
- class PythonMod(sympy.Function):
- nargs: tuple[int, ...] = (2,)
- precedence: int = 35 # lower precedence than add
- is_integer: bool = True
- @classmethod
- def eval(cls, p: sympy.Expr, q: sympy.Expr) -> sympy.Expr | None:
- # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint
- # Triggered by sympy.solvers.inequalities.reduce_inequalities
- # assert p.is_integer, p
- # assert q.is_integer, q
- if q.is_zero:
- raise ZeroDivisionError("Modulo by zero")
- # Three cases:
- # 1. p == 0
- # 2. p is either q or -q
- # 3. p is integer and q == 1
- if p is S.Zero or p in (q, -q) or q == 1:
- return S.Zero
- # Evaluate if they are both literals.
- if q.is_Number and p.is_Number:
- return p % q
- # If q == 2, it's a matter of whether p is odd or even.
- if q.is_Number and q == 2:
- if p.is_even:
- return S.Zero
- if p.is_odd:
- return S.One
- # If p is a multiple of q.
- r = p / q
- if r.is_integer:
- return S.Zero
- # If p < q and its ratio is positive, then:
- # - floor(p / q) = 0
- # - p % q = p - floor(p / q) * q = p
- less = p < q
- # pyrefly: ignore [missing-attribute]
- if less.is_Boolean and bool(less) and r.is_positive:
- return p
- if sympy.Mod(p, q) == 0:
- return S.Zero
- return None
- # NB: args[1] for PythonMod
- def _eval_is_nonnegative(self) -> bool | None:
- return True if self.args[1].is_positive else None # type: ignore[attr-defined]
- def _eval_is_nonpositive(self) -> bool | None:
- return True if self.args[1].is_negative else None # type: ignore[attr-defined]
- def _ccode(self, printer) -> str:
- # pyrefly: ignore [missing-attribute]
- p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
- # pyrefly: ignore [missing-attribute]
- q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
- # pyrefly: ignore [missing-attribute]
- abs_q = str(q) if self.args[1].is_positive else f"abs({q})"
- return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}"
- # Generic modulus: only defined on non-negative arguments
- class Mod(sympy.Function):
- nargs = (2,)
- precedence: int = 35 # lower precedence than add
- is_integer = True
- is_nonnegative = True
- @classmethod
- def eval(cls, p, q):
- # This was adapted from: sympy/core/mod.py
- # Triggered by
- # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full
- # assert p.is_integer, p
- # assert q.is_integer, q
- if q.is_zero:
- raise ZeroDivisionError("Modulo by zero")
- # Three cases:
- # 1. p == 0
- # 2. p is either q or -q
- # 3. p is integer and q == 1
- if p is S.Zero or p in (q, -q) or q == 1:
- return S.Zero
- # Evaluate if they are both literals.
- if q.is_Number and p.is_Number:
- if p < 0:
- raise AssertionError(p)
- if q < 1:
- raise AssertionError(q)
- return p % q
- # If q == 2, it's a matter of whether p is odd or even.
- if q.is_Number and q == 2:
- if p.is_even:
- return S.Zero
- if p.is_odd:
- return S.One
- # If p is a multiple of q.
- r = p / q
- if r.is_integer:
- return S.Zero
- # If p < q and its ratio is positive, then:
- # - floor(p / q) = 0
- # - p % q = p - floor(p / q) * q = p
- less = p < q
- if less.is_Boolean and bool(less) and r.is_positive:
- return p
- class CleanDiv(FloorDiv):
- """
- Div where we can assume no rounding.
- This is to enable future optimizations.
- """
- # Don't use sympy ceiling/floor as they will attempt simplifications involving
- # frac
- class CeilToInt(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, number):
- # assert number.is_integer is not True, number
- if number in (sympy.oo, int_oo):
- return int_oo
- if number in (-sympy.oo, -int_oo):
- return -int_oo
- if isinstance(number, sympy.Number):
- return sympy.Integer(math.ceil(float(number)))
- def _ccode(self, printer) -> str:
- # pyrefly: ignore [missing-attribute]
- number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5)
- return f"ceil({number})"
- class FloorToInt(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, number):
- if number in (sympy.oo, int_oo):
- return int_oo
- if number in (-sympy.oo, int_oo):
- return -int_oo
- if isinstance(number, sympy.Integer):
- return number
- if isinstance(number, sympy.Number):
- return sympy.Integer(math.floor(float(number)))
- class CeilDiv(sympy.Function):
- """
- Div used in indexing that rounds up.
- """
- is_integer = True
- def __new__(cls, base, divisor):
- base = sympy.sympify(base)
- divisor = sympy.sympify(divisor)
- if sympy.gcd(base, divisor) == divisor:
- return CleanDiv(base, divisor)
- else:
- return FloorDiv(base + (divisor - 1), divisor)
- class LShift(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, base, shift):
- if shift < 0:
- raise ValueError("negative shift count")
- return base * 2**shift
- class RShift(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, base, shift):
- if shift < 0:
- raise ValueError("negative shift count")
- return FloorDiv(base, 2**shift)
- class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
- def __new__(cls, *original_args, **assumptions):
- from sympy.core.parameters import global_parameters
- evaluate = assumptions.pop("evaluate", global_parameters.evaluate)
- args = (sympify(arg) for arg in original_args)
- # See the comment in _satisfy_unique_summations_symbols.
- unique_summations_symbols = (
- None
- if not evaluate
- else cls._satisfy_unique_summations_symbols(original_args)
- )
- if evaluate:
- try:
- # first standard filter, for cls.zero and cls.identity
- # also reshape Max(a, Max(b, c)) to Max(a, b, c)
- args = frozenset(cls._new_args_filter(args)) # type: ignore[assignment]
- except ShortCircuit:
- return cls.zero # type: ignore[attr-defined]
- # No need to run _collapse_arguments and _find_localzeros, see the comment
- # in _satisfy_unique_summations_symbols.
- if unique_summations_symbols is None:
- # remove redundant args that are easily identified
- args = cls._collapse_arguments(args, **assumptions)
- # find local zeros
- args = cls._find_localzeros(args, **assumptions)
- args = frozenset(args)
- if not args:
- return cls.identity # type: ignore[attr-defined]
- if len(args) == 1:
- return list(args).pop()
- # base creation
- obj = Expr.__new__(cls, *ordered(args), **assumptions)
- obj._argset = args
- obj.unique_summations_symbols = unique_summations_symbols
- return obj
- @classmethod
- def _satisfy_unique_summations_symbols(
- cls, args
- ) -> set[sympy.core.symbol.Symbol] | None:
- """
- One common case in some models is building expressions of the form
- max(max(max(a+b...), c+d), e+f) which is simplified to max(a+b, c+d, e+f, ...).
- For such expressions, we call the Max constructor X times (once for each nested
- max) and the expression gets flattened.
- An expensive cost in constructing those expressions is running _collapse_arguments
- and _find_localzeros. However, those two optimizations are unnecessary when the args
- to max are all of the form a+b, c+d, ..etc where each term uses a unique set of symbols.
- This function is used to detect such properties of the expressions we are building
- and if so inform that we do not need to run those optimizations. To detect those,
- we store a property in the expression that tells that this expression is a min/max
- operation over terms that use unique symbols "unique_summations_symbols". This property
- also memoize the set of symbols used in all the terms to make it faster to detect this
- property inductively.
- When we apply max to add a new term, all we need to do is check if the new term uses
- unique symbols (with respect to existing terms and itself).
- Example:
- t = Max(a+b, c+d) ==> satisfies the property
- Max(t, h+j) ==> h,j not in [a,b,c,d] => satisfy the property.
- The function returns None if the new expression does not satisfy the unique_summations_symbols
- property. Otherwise, it returns a new set of unique symbols.
- """
- if len(args) != 2:
- return None
- (lhs, rhs) = (
- (args[1], args[0])
- if isinstance(args[1], MinMaxBase)
- else (args[0], args[1])
- )
- if not _is_symbols_binary_summation(rhs):
- return None
- # base case max(a+b, c+d) ==> satisfies the property if a+b and c+d use unique symbols.
- if _is_symbols_binary_summation(lhs):
- return cls._unique_symbols(args)
- # inductive case max(t, h+j) ==> satisfies the property if h, j not in t.unique_summations_symbols
- if isinstance(lhs, MinMaxBase):
- lhs_unique_summations_symbols = getattr(
- lhs, "unique_summations_symbols", None
- )
- if lhs_unique_summations_symbols is not None:
- return cls._unique_symbols([rhs], lhs_unique_summations_symbols)
- return None
- @classmethod
- def _unique_symbols(
- cls, args, initial_set: set[sympy.core.symbol.Symbol] | None = None
- ) -> set[sympy.core.symbol.Symbol] | None:
- """
- Return seen_symbols if all atoms in all args are all unique symbols,
- else returns None. initial_set can be used to represent initial value for seen_symbols
- """
- seen_symbols = set() if initial_set is None else initial_set
- for arg in args:
- for element in arg.atoms():
- if not isinstance(element, sympy.core.symbol.Symbol):
- return None
- elif element in seen_symbols:
- return None
- else:
- seen_symbols.add(element)
- return seen_symbols
- @classmethod
- def _collapse_arguments(cls, args, **assumptions):
- """Remove redundant args.
- Examples
- ========
- >>> from sympy import Min, Max
- >>> from sympy.abc import a, b, c, d, e
- Any arg in parent that appears in any
- parent-like function in any of the flat args
- of parent can be removed from that sub-arg:
- >>> Min(a, Max(b, Min(a, c, d)))
- Min(a, Max(b, Min(c, d)))
- If the arg of parent appears in an opposite-than parent
- function in any of the flat args of parent that function
- can be replaced with the arg:
- >>> Min(a, Max(b, Min(c, d, Max(a, e))))
- Min(a, Max(b, Min(a, c, d)))
- """
- if not args:
- return args
- args = list(ordered(args))
- if cls is Min:
- other = Max
- else:
- other = Min # type: ignore[assignment]
- # find global comparable max of Max and min of Min if a new
- # value is being introduced in these args at position 0 of
- # the ordered args
- if args[0].is_number:
- sifted = mins, maxs = [], [] # type: ignore[var-annotated]
- for i in args:
- for v in walk(i, Min, Max):
- if v.args[0].is_comparable:
- sifted[isinstance(v, Max)].append(v)
- small = Min.identity
- for i in mins:
- v = i.args[0]
- if v.is_number and (v < small) == True: # noqa: E712
- small = v
- big = Max.identity
- for i in maxs:
- v = i.args[0]
- if v.is_number and (v > big) == True: # noqa: E712
- big = v
- # at the point when this function is called from __new__,
- # there may be more than one numeric arg present since
- # local zeros have not been handled yet, so look through
- # more than the first arg
- if cls is Min:
- for arg in args:
- if not arg.is_number:
- break
- if (arg < small) == True: # noqa: E712
- small = arg
- elif cls == Max:
- for arg in args:
- if not arg.is_number:
- break
- if (arg > big) == True: # noqa: E712
- big = arg
- T = None
- if cls is Min:
- if small != Min.identity:
- other = Max
- T = small
- elif big != Max.identity:
- other = Min # type: ignore[assignment]
- T = big
- if T is not None:
- # remove numerical redundancy
- for i in range(len(args)):
- a = args[i]
- if isinstance(a, other):
- a0 = a.args[0]
- if ( # noqa: E712
- (a0 > T) if other == Max else (a0 < T) # noqa: E712
- ) == True: # noqa: E712
- args[i] = cls.identity # type: ignore[attr-defined]
- # remove redundant symbolic args
- def do(ai, a):
- if not isinstance(ai, (Min, Max)):
- return ai
- cond = a in ai.args
- if not cond:
- return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
- if isinstance(ai, cls):
- # pyrefly: ignore [missing-attribute]
- return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
- return a
- for i, a in enumerate(args):
- args[i + 1 :] = [do(ai, a) for ai in args[i + 1 :]]
- # factor out common elements as for
- # Min(Max(x, y), Max(x, z)) -> Max(x, Min(y, z))
- # and vice versa when swapping Min/Max -- do this only for the
- # easy case where all functions contain something in common;
- # trying to find some optimal subset of args to modify takes
- # too long
- def factor_minmax(args):
- is_other = lambda arg: isinstance(arg, other) # noqa: E731
- other_args, remaining_args = sift(args, is_other, binary=True)
- if not other_args:
- return args
- # Min(Max(x, y, z), Max(x, y, u, v)) -> {x,y}, ({z}, {u,v})
- arg_sets = [set(arg.args) for arg in other_args]
- common = set.intersection(*arg_sets)
- if not common:
- return args
- new_other_args = list(common)
- arg_sets_diff = [arg_set - common for arg_set in arg_sets]
- # If any set is empty after removing common then all can be
- # discarded e.g. Min(Max(a, b, c), Max(a, b)) -> Max(a, b)
- if all(arg_sets_diff):
- other_args_diff = [other(*s, evaluate=False) for s in arg_sets_diff]
- new_other_args.append(cls(*other_args_diff, evaluate=False))
- other_args_factored = other(*new_other_args, evaluate=False)
- return remaining_args + [other_args_factored]
- if len(args) > 1:
- args = factor_minmax(args)
- return args
- @classmethod
- def _new_args_filter(cls, arg_sequence):
- """
- Generator filtering args.
- first standard filter, for cls.zero and cls.identity.
- Also reshape ``Max(a, Max(b, c))`` to ``Max(a, b, c)``,
- and check arguments for comparability
- """
- for arg in arg_sequence:
- # pre-filter, checking comparability of arguments
- if (
- not isinstance(arg, Expr)
- or arg.is_extended_real is False
- or (arg.is_number and not arg.is_comparable)
- ):
- raise ValueError(f"The argument '{arg}' is not comparable.")
- if arg == cls.zero: # type: ignore[attr-defined]
- raise ShortCircuit(arg)
- elif arg == cls.identity: # type: ignore[attr-defined]
- continue
- elif arg.func == cls:
- yield from arg.args
- else:
- yield arg
- @classmethod
- def _find_localzeros(cls, values, **options):
- """
- Sequentially allocate values to localzeros.
- When a value is identified as being more extreme than another member it
- replaces that member; if this is never true, then the value is simply
- appended to the localzeros.
- Unlike the sympy implementation, we only look for zero and one, we don't
- do generic is connected test pairwise which is slow
- """
- # First, collapse all numeric arguments
- other_values = set()
- num_value = None
- for arg in values:
- if arg.is_Number:
- if num_value is None:
- num_value = arg
- else:
- if cls is Max:
- num_value = max(num_value, arg)
- elif cls is Min:
- num_value = min(num_value, arg)
- else:
- raise AssertionError(f"impossible {cls}")
- else:
- other_values.add(arg)
- # Special cases when there is only one symbolic value
- if num_value is None:
- return other_values
- if len(other_values) == 0:
- return {num_value}
- if len(other_values) == 1:
- other_value = next(iter(other_values))
- if num_value in (0.0, 0) and other_value.is_nonnegative:
- return other_values if cls is Max else {num_value}
- if num_value == 1 and other_value.is_positive:
- return other_values if cls is Max else {num_value}
- other_values.add(num_value)
- return other_values
- _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731
- _eval_is_antihermitian = lambda s: _torf( # noqa: E731
- i.is_antihermitian
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_commutative = lambda s: _torf( # noqa: E731
- i.is_commutative
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_complex = lambda s: _torf(i.is_complex for i in s.args) # noqa: E731
- _eval_is_composite = lambda s: _torf(i.is_composite for i in s.args) # noqa: E731
- _eval_is_even = lambda s: _torf(i.is_even for i in s.args) # noqa: E731
- _eval_is_finite = lambda s: _torf(i.is_finite for i in s.args) # noqa: E731
- _eval_is_hermitian = lambda s: _torf(i.is_hermitian for i in s.args) # noqa: E731
- _eval_is_imaginary = lambda s: _torf(i.is_imaginary for i in s.args) # noqa: E731
- _eval_is_infinite = lambda s: _torf(i.is_infinite for i in s.args) # noqa: E731
- _eval_is_integer = lambda s: _torf(i.is_integer for i in s.args) # noqa: E731
- _eval_is_irrational = lambda s: _torf(i.is_irrational for i in s.args) # noqa: E731
- _eval_is_negative = lambda s: _torf(i.is_negative for i in s.args) # noqa: E731
- _eval_is_noninteger = lambda s: _torf(i.is_noninteger for i in s.args) # noqa: E731
- _eval_is_nonnegative = lambda s: _torf( # noqa: E731
- i.is_nonnegative
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_nonpositive = lambda s: _torf( # noqa: E731
- i.is_nonpositive
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_nonzero = lambda s: _torf(i.is_nonzero for i in s.args) # noqa: E731
- _eval_is_odd = lambda s: _torf(i.is_odd for i in s.args) # noqa: E731
- _eval_is_polar = lambda s: _torf(i.is_polar for i in s.args) # noqa: E731
- _eval_is_positive = lambda s: _torf(i.is_positive for i in s.args) # noqa: E731
- _eval_is_prime = lambda s: _torf(i.is_prime for i in s.args) # noqa: E731
- _eval_is_rational = lambda s: _torf(i.is_rational for i in s.args) # noqa: E731
- _eval_is_real = lambda s: _torf(i.is_real for i in s.args) # noqa: E731
- _eval_is_extended_real = lambda s: _torf( # noqa: E731
- i.is_extended_real
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_transcendental = lambda s: _torf( # noqa: E731
- i.is_transcendental
- for i in s.args # noqa: E731
- ) # noqa: E731
- _eval_is_zero = lambda s: _torf(i.is_zero for i in s.args) # noqa: E731
- class Max(MinMaxBase, Application): # type: ignore[misc]
- r"""
- Return, if possible, the maximum value of the list.
- """
- zero = S.Infinity
- identity = S.NegativeInfinity
- def _eval_is_positive(self): # type:ignore[override]
- return fuzzy_or(a.is_positive for a in self.args) # type: ignore[attr-defined]
- def _eval_is_nonnegative(self): # type:ignore[override]
- return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
- def _eval_is_negative(self): # type:ignore[override]
- # pyrefly: ignore [missing-attribute]
- return fuzzy_and(a.is_negative for a in self.args)
- class Min(MinMaxBase, Application): # type: ignore[misc]
- """
- Return, if possible, the minimum value of the list.
- """
- zero = S.NegativeInfinity
- identity = S.Infinity
- def _eval_is_positive(self): # type:ignore[override]
- return fuzzy_and(a.is_positive for a in self.args) # type: ignore[attr-defined]
- def _eval_is_nonnegative(self): # type:ignore[override]
- return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
- def _eval_is_negative(self): # type:ignore[override]
- # pyrefly: ignore [missing-attribute]
- return fuzzy_or(a.is_negative for a in self.args)
- def safe_pow(base, exp):
- sign = 1
- if base < 0:
- base = -base
- sign = 1 if exp % 2 == 0 else -1
- return sign * _safe_pow(base, exp)
- # Prevent people from overflowing pow
- def _safe_pow(base, exponent):
- if exponent < 0:
- raise ValueError("Exponent must be non-negative.")
- if exponent == 0:
- return 1
- half_exp = safe_pow(base, exponent // 2)
- if half_exp is int_oo:
- return int_oo
- # TODO: microoptimization is to avoid overflowing into arbitrary precision
- # and detect overflow prior to doing operations
- result = half_exp * half_exp
- if result > sys.maxsize:
- return int_oo
- if exponent % 2 == 1:
- result *= base
- if result > sys.maxsize:
- return int_oo
- return result
- class PowByNatural(sympy.Function):
- is_integer = True
- precedence: int = 50 # precedence of mul
- @classmethod
- def eval(cls, base, exp):
- if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
- r = safe_pow(base, exp)
- if r in (-int_oo, int_oo):
- return r
- return sympy.Integer(r)
- if isinstance(exp, sympy.Integer):
- # Rely on regular sympy Pow for this (note that iterated
- # multiplication turns into a Pow anyway, you can't escape!!)
- return sympy.Pow(base, exp)
- if exp in (int_oo, sympy.oo):
- if base.is_nonnegative:
- return int_oo
- elif base.is_negative:
- return sympy.zoo # this is apparently what (-2)**sympy.oo does
- # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp
- # is a natural number if we do
- # base is assumed to be nonnegative, thereby prevent complex numbers from
- # occurring
- class FloatPow(sympy.Function):
- is_real = True
- precedence: int = 60 # precedence of pow
- @classmethod
- def eval(cls, base, exp):
- # NB: These test sympy.Number, not sympy.Float, because:
- # - Sometimes we may have sympy.oo or int_oo, and that's not a Float
- # (but coerces to math.Inf)
- # - Sometimes Float(0.0) will unpredictably decay to Integer(0),
- # but we should still accept it in floatey contexts
- if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
- return sympy.Float(float(base) ** float(exp))
- # NB: do not do any nontrivial reasoning
- # Overloaded to be compatible with regular Python.
- # https://github.com/pytorch/pytorch/issues/90900
- #
- # In particular, sympy division is willing to simplify x/x == 1
- # where 1 is an integer, but this must be a float if x was float.
- class FloatTrueDiv(sympy.Function):
- is_real = True
- precedence: int = 35 # lower precedence than add
- @classmethod
- def eval(cls, base, divisor):
- # assert base.is_integer is not True, base
- # assert divisor.is_integer is not True, divisor
- if divisor.is_zero:
- raise ZeroDivisionError("division by zero")
- if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number):
- return sympy.Float(float(base) / float(divisor))
- # Overloaded to be compatible with regular Python. We distinguish this from
- # FloatTrueDiv, because the code generation has to be different for this case:
- # Python has a fancy algorithm for integer true division that isn't just
- # "promote both arguments to float and use float division", so you need to
- # codegen it differently. While technically you can work it out from the
- # types of the input, this is often inconvenient to do in Inductor codegen,
- # so just have a different operator
- # NB: Right now, Inductor codegen doesn't implement this correctly lol
- class IntTrueDiv(sympy.Function):
- is_real = True
- precedence: int = 35 # lower precedence than add
- @classmethod
- def eval(cls, base, divisor):
- if divisor.is_zero:
- raise ZeroDivisionError("division by zero")
- if (
- isinstance(base, sympy.Number)
- and isinstance(divisor, sympy.Number)
- and (is_infinite(base) or is_infinite(divisor))
- ):
- # Don't have to worry about precision here, you're getting zero or
- # inf from the division
- return sympy.Float(float(base) / float(divisor))
- if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
- return sympy.Float(int(base) / int(divisor))
- def _ccode(self, printer) -> str:
- # pyrefly: ignore [missing-attribute]
- base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
- # pyrefly: ignore [missing-attribute]
- divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
- return f"((int){base}/(int){divisor})"
- # TODO: As an indicator, this != 0 implies == 1 (and vice versa).
- # Because we do not have the ability to guard on the stride permutation
- # at the moment, it is hard to make further inferences when this is true,
- # as although we know the tensor is contiguous in *some* layout, we don't
- # know which one (however, you could, for example, make the inference that
- # reshaping this to a 1D tensor can be guard-free.)
- class IsNonOverlappingAndDenseIndicator(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, *args):
- if len(args) % 2 != 0:
- raise AssertionError(
- f"expected an even number of arguments, got {len(args)}"
- )
- dim = len(args) // 2
- sizes = args[0:dim]
- strides = args[dim:]
- # sym_node imported in torch.__init__. Local import to avoid an import cycle
- from torch.fx.experimental.symbolic_shapes import (
- eval_is_non_overlapping_and_dense,
- )
- if all(isinstance(a, sympy.Integer) for a in args):
- return eval_is_non_overlapping_and_dense(
- [int(a) for a in sizes], [int(a) for a in strides]
- )
- if dim == 1:
- # Manually implement the rank one short circuit
- if strides[0].is_Number and strides[0] == 1:
- return 1
- if sizes[0].is_Number and sizes[0] < 2:
- return 1
- # return 0 case covered by case above
- # TODO: Inability to access size-obliviousness sucks: if we have a
- # size oblivious test on a size-like unbacked SymInt, we could
- # confidently return zero when we have a size-like u0 stride
- # and a size-like u1 size. Maybe a fancy ValueRanges analysis for
- # this function could help figure this out.
- if all(isinstance(a, sympy.Integer) for a in strides):
- if dim == 0:
- raise AssertionError("dim must not be zero")
- # When all strides are integral, we can sort, and the size for the
- # largest stride doesn't matter and can be arbitrarily symbolic
- s_sizes, s_strides = zip(
- *sorted(zip(sizes, strides, strict=True), key=operator.itemgetter(1)),
- strict=True,
- )
- # Put something arbitrary in the max size spot, it'll be ignored
- if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]):
- s_sizes = s_sizes[:-1] + (42,)
- # We can reuse the regular eval, because it is invariant to
- # permutation of dimensions
- return eval_is_non_overlapping_and_dense(
- [int(a) for a in s_sizes], [int(a) for a in s_strides]
- )
- return None
- # NB: this is inconsistent with math.trunc in Python
- class TruncToFloat(sympy.Function):
- is_real = True
- @classmethod
- def eval(cls, number):
- if number in (sympy.oo, -sympy.oo):
- return number
- # assert number.is_integer is not True, number
- if isinstance(number, sympy.Number):
- # NB: It is safe to use truncation to integer, which is what
- # math.trunc does, as Python integers are arbitrary precision and
- # so we are guaranteed not to lose precision when we do this
- return sympy.Float(math.trunc(float(number)))
- class TruncToInt(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, number):
- # assert number.is_integer is not True, number
- if number in (sympy.oo, int_oo):
- return int_oo
- if number in (-sympy.oo, -int_oo):
- return -int_oo
- if isinstance(number, sympy.Number):
- return sympy.Integer(math.trunc(float(number)))
- # This is float -> int
- class RoundToInt(sympy.Function):
- is_integer = True
- @classmethod
- def eval(cls, number):
- # assert number.is_integer is not True, number
- if number is sympy.oo:
- return int_oo
- if number is -sympy.oo:
- return -int_oo
- if isinstance(number, sympy.Number):
- return sympy.Integer(round(float(number), 0))
- # To get float -> int, Python style round semantics.
- #
- # x = PyFloat_AsDouble(self);
- # if (o_ndigits == Py_None) {
- # /* single-argument round or with None ndigits:
- # * round to nearest integer */
- # rounded = round(x);
- # if (fabs(x-rounded) == 0.5)
- # /* halfway case: round to even */
- # rounded = 2.0*round(x/2.0);
- # return PyLong_FromDouble(rounded);
- # }
- # NB: Like Round, this only ever returns floats. ndigits cannot be None
- class RoundDecimal(sympy.Function):
- is_real = True
- @classmethod
- def eval(cls, number, ndigits):
- # assert number.is_integer is not True, number
- if isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer):
- return sympy.Float(round(float(number), int(ndigits)))
- class ToFloat(sympy.Function):
- is_real = True
- @classmethod
- def eval(cls, number):
- if number in [sympy.oo, -sympy.oo]:
- return number
- if isinstance(number, sympy.Integer):
- return sympy.Float(int(number))
- if number is int_oo:
- return sympy.oo
- if number is -int_oo:
- return -sympy.oo
- class Identity(sympy.Function):
- """
- Prevents expansion and other optimizations
- """
- precedence = 10
- def __repr__(self) -> str: # type: ignore[override]
- # pyrefly: ignore [missing-attribute]
- return f"Identity({self.args[0]})"
- def _sympystr(self, printer) -> str:
- """Controls how sympy's StrPrinter prints this"""
- # pyrefly: ignore [missing-attribute]
- return f"({printer.doprint(self.args[0])})"
- def _eval_is_real(self):
- # pyrefly: ignore [missing-attribute]
- return self.args[0].is_real
- def _eval_is_integer(self):
- return self.args[0].is_integer # type: ignore[attr-defined]
- @property
- def is_number(self):
- # Treat Identity as numeric only when the argument is comparable.
- # This avoids creating numeric non-comparable Identity(I) terms.
- # pyrefly: ignore [missing-attribute]
- return bool(self.args[0].is_number and self.args[0].is_comparable)
- @property
- def is_comparable(self):
- # Delegate comparability to the wrapped argument.
- # pyrefly: ignore [missing-attribute]
- return bool(self.args[0].is_comparable)
- def _eval_expand_identity(self, **hints):
- # Removes the identity op.
- # pyrefly: ignore [missing-attribute]
- return self.args[0]
- def __int__(self) -> int:
- # pyrefly: ignore [missing-attribute]
- return int(self.args[0])
- def _identity_atom_compare(self, other, op):
- """
- Fast path for comparing wrapped numeric atomics against other numeric atomics.
- Keep compound expressions on SymPy's default symbolic path.
- """
- arg = self.args[0]
- if isinstance(other, int):
- other = sympy.Integer(other)
- if not isinstance(other, sympy.Expr):
- return None
- if not (arg.is_Atom and arg.is_number and arg.is_comparable):
- return None
- if not (other.is_Atom and other.is_number and other.is_comparable):
- return None
- return sympy.S.true if op(arg, other) else sympy.S.false
- def __ge__(self, other):
- out = self._identity_atom_compare(other, lambda a, b: a >= b)
- return out if out is not None else super().__ge__(other)
- def __gt__(self, other):
- out = self._identity_atom_compare(other, lambda a, b: a > b)
- return out if out is not None else super().__gt__(other)
- def __le__(self, other):
- out = self._identity_atom_compare(other, lambda a, b: a <= b)
- return out if out is not None else super().__le__(other)
- def __lt__(self, other):
- out = self._identity_atom_compare(other, lambda a, b: a < b)
- return out if out is not None else super().__lt__(other)
- def __float__(self) -> float:
- # pyrefly: ignore [missing-attribute]
- return float(self.args[0])
- def make_opaque_unary_fn(name):
- class OpaqueUnaryFn(sympy.Function):
- """
- Unlike the builtin sympy functions on real numbers like sympy.sqrt,
- these equivalents do not do any nontrivial reasoning besides
- constant propagation. This helps avoid performing transformations
- that are valid for real numbers but are invalid for floating point;
- in particular, while we are willing to make optimizations that change
- numerics for Tensor compute, we are NOT willing to make optimizations
- that change numerics for size compute.
- """
- _torch_handler_name = name
- _torch_unpickler = make_opaque_unary_fn
- @classmethod
- def eval(cls, a):
- if isinstance(a, (sympy.Integer, sympy.Float)):
- # Python converts to float64 before computing, c.f.
- # >>> math.sin(2**53+1)
- # -0.848925964814655
- # >>> math.sin(float(2**53+1))
- # -0.848925964814655
- try:
- return sympy.Float(getattr(math, name)(float(a)))
- # Just use sympy semantics for infinity/overflow, you might get some
- # weird objects but ask silly questions, get silly answers
- except OverflowError:
- return getattr(sympy, name)(a)
- elif a in [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo, int_oo, -int_oo]:
- if a is int_oo:
- a = sympy.oo
- if a is -int_oo:
- a = -sympy.oo
- if name == "log2":
- return sympy.log(a, 2)
- return getattr(sympy, name)(a)
- return None
- nm = "OpaqueUnaryFn_" + name
- OpaqueUnaryFn.__name__ = nm
- OpaqueUnaryFn.__qualname__ = nm
- return OpaqueUnaryFn
- # Keep in sync with math_op_names in torch/fx/experimental/sym_node.py
- OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt")
- OpaqueUnaryFn_cos = make_opaque_unary_fn("cos")
- OpaqueUnaryFn_cosh = make_opaque_unary_fn("cosh")
- OpaqueUnaryFn_sin = make_opaque_unary_fn("sin")
- OpaqueUnaryFn_sinh = make_opaque_unary_fn("sinh")
- OpaqueUnaryFn_tan = make_opaque_unary_fn("tan")
- OpaqueUnaryFn_tanh = make_opaque_unary_fn("tanh")
- OpaqueUnaryFn_asin = make_opaque_unary_fn("asin")
- OpaqueUnaryFn_acos = make_opaque_unary_fn("acos")
- OpaqueUnaryFn_atan = make_opaque_unary_fn("atan")
- OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
- OpaqueUnaryFn_log = make_opaque_unary_fn("log")
- OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
- OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2")
- def make_opaque_bitwise_fn(name, real_op_name):
- if name == "bitwise_and":
- prec = PRECEDENCE["BitwiseAnd"]
- elif name == "bitwise_xor":
- prec = PRECEDENCE["BitwiseXor"]
- elif name == "bitwise_or":
- prec = PRECEDENCE["BitwiseOr"]
- else:
- raise AssertionError(f"unrecognized {name}")
- class BitwiseFn(sympy.Function):
- _torch_handler_name = name
- precedence: int = prec
- _torch_unpickler = functools.partial(
- make_opaque_bitwise_fn, real_op_name=real_op_name
- )
- @classmethod
- def eval(cls, a, b):
- if a.is_Boolean and b.is_Boolean:
- return getattr(operator, real_op_name)(a, b)
- if a.is_Boolean:
- a = sympy.Integer(1 if a else 0)
- if b.is_Boolean:
- b = sympy.Integer(1 if b else 0)
- if isinstance(a, (sympy.Integer, int)) and isinstance(
- b, (sympy.Integer, int)
- ):
- return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b)))
- return None
- nm = "BitwiseFn_" + name
- BitwiseFn.__name__ = nm
- BitwiseFn.__qualname__ = nm
- return BitwiseFn
- BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_")
- BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_")
- BitwiseFn_bitwise_xor = make_opaque_bitwise_fn("bitwise_xor", "xor")
|