| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417 |
- # mypy: allow-untyped-defs
- import mpmath.libmp as mlib # type: ignore[import-untyped]
- import sympy
- from sympy import Expr
- from sympy.core.decorators import _sympifyit
- from sympy.core.expr import AtomicExpr
- from sympy.core.numbers import Number
- from sympy.core.parameters import global_parameters
- from sympy.core.singleton import S, Singleton
# pyrefly: ignore [invalid-inheritance]
class IntInfinity(Number, metaclass=Singleton):
    r"""The positive infinite element of the extended integers.

    ``IntInfinity`` compares greater than every other integer.  What sets
    it apart from sympy's built-in ``oo`` is that it reports
    ``is_integer = True``; it also orders strictly below ``oo``.

    The value is a singleton: reach it via ``S.IntInfinity`` or import
    the module-level alias ``int_oo``.
    """

    # NB: this cannot be flagged as infinite, because sympy considers the
    # integer and infinite assumptions mutually inconsistent.  Unlike
    # sympy.oo, we also report that we are complex.
    is_integer = True
    is_number = True
    is_commutative = True
    is_comparable = True
    is_extended_real = True
    is_extended_positive = True
    is_prime = False

    # Ensure we get dispatched to before plain numbers
    _op_priority = 100.0

    __slots__ = ()

    def __new__(cls):
        return AtomicExpr.__new__(cls)

    def _sympystr(self, printer) -> str:
        """Return the text sympy's string printer shows for this value."""
        return "int_oo"

    def _eval_subs(self, old, new):
        """Substitute ``new`` when ``old`` is this singleton, else return None."""
        return new if self == old else None

    # We could do these, not sure about it
    """
    def _eval_evalf(self, prec=None):
        return Float('inf')

    def evalf(self, prec=None, **options):
        return self._eval_evalf(prec)
    """

    @_sympifyit("other", NotImplemented)
    def __add__(self, other):
        """Add ``other``; float infinities absorb, opposite int infinity is NaN."""
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__add__(self, other)
        if other in (S.Infinity, S.NegativeInfinity):
            return other
        if other in (S.NegativeIntInfinity, S.NaN):
            return S.NaN
        return self

    __radd__ = __add__

    @_sympifyit("other", NotImplemented)
    def __sub__(self, other):
        """Subtract ``other``; ``int_oo - int_oo`` and NaN operands give NaN."""
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__sub__(self, other)
        if other is S.Infinity:
            return S.NegativeInfinity
        if other is S.NegativeInfinity:
            return S.Infinity
        if other in (S.IntInfinity, S.NaN):
            return S.NaN
        return self

    @_sympifyit("other", NotImplemented)
    def __rsub__(self, other):
        """Compute ``other - self`` as ``(-self) + other``."""
        negated = -self
        return negated.__add__(other)

    @_sympifyit("other", NotImplemented)
    def __mul__(self, other):
        """Multiply by ``other``; zero or NaN give NaN, otherwise sign decides."""
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__mul__(self, other)
        if other.is_zero or other is S.NaN:
            return S.NaN
        if other.is_extended_positive:
            return self
        return S.NegativeIntInfinity

    __rmul__ = __mul__

    @_sympifyit("other", NotImplemented)
    def __truediv__(self, other):
        """True-divide by ``other``; the result is a float infinity or NaN."""
        if not (isinstance(other, Number) and global_parameters.evaluate):
            return Number.__truediv__(self, other)
        undefined_divisors = (
            S.Infinity,
            S.IntInfinity,
            S.NegativeInfinity,
            S.NegativeIntInfinity,
            S.NaN,
        )
        if other in undefined_divisors:
            return S.NaN
        if other.is_extended_nonnegative:
            return S.Infinity  # truediv produces float
        return S.NegativeInfinity  # truediv produces float

    def __abs__(self):
        return S.IntInfinity

    def __neg__(self):
        return S.NegativeIntInfinity

    def _eval_power(self, expt):
        """Evaluate ``self ** expt`` for the cases that can be decided."""
        if expt.is_extended_positive:
            return S.IntInfinity
        if expt.is_extended_negative:
            return S.Zero
        if expt is S.NaN or expt is S.ComplexInfinity:
            return S.NaN
        if expt.is_extended_real is False and expt.is_number:
            from sympy.functions.elementary.complexes import re

            # For a non-real numeric exponent, the sign of its real part
            # decides the result.
            real_part = re(expt)
            if real_part.is_positive:
                return S.ComplexInfinity
            if real_part.is_negative:
                return S.Zero
            if real_part.is_zero:
                return S.NaN

            return self ** expt.evalf()

    def _as_mpf_val(self, prec):
        return mlib.finf

    def __hash__(self):
        return super().__hash__()

    def __eq__(self, other):
        return other is S.IntInfinity

    def __ne__(self, other):
        return other is not S.IntInfinity

    def __gt__(self, other):
        # sympy.oo > int_oo, and int_oo is not strictly above itself
        # (consistency with sympy.oo).
        if other is S.Infinity or other is S.IntInfinity:
            return sympy.false
        return sympy.true

    def __ge__(self, other):
        # Only sympy.oo is strictly larger than int_oo.
        return sympy.false if other is S.Infinity else sympy.true

    def __lt__(self, other):
        # Only sympy.oo is strictly larger than int_oo.
        return sympy.true if other is S.Infinity else sympy.false

    def __le__(self, other):
        # sympy.oo > int_oo, and int_oo <= int_oo holds
        # (consistency with sympy.oo).
        if other is S.Infinity or other is S.IntInfinity:
            return sympy.true
        return sympy.false

    @_sympifyit("other", NotImplemented)
    def __mod__(self, other):
        """Return NaN for any ``Expr`` operand, ``NotImplemented`` otherwise."""
        return S.NaN if isinstance(other, Expr) else NotImplemented

    __rmod__ = __mod__

    def floor(self):
        return self

    def ceiling(self):
        return self
- int_oo = S.IntInfinity  # module-level alias bound to the S.IntInfinity singleton
def is_infinite(expr) -> bool:
    """Return True when ``expr`` is one of the four infinity singletons.

    Both sympy's built-in infinities (``oo``, ``-oo``) and PyTorch's
    integer infinities (``int_oo``, ``-int_oo``) are recognised.

    Note: sympy's ``is_finite`` property is not usable for this check.
    ``IntInfinity`` and ``NegativeIntInfinity`` declare ``is_integer=True``,
    which implies ``is_finite=True`` in sympy's assumption system.
    """
    all_infinities = (
        S.Infinity,
        S.NegativeInfinity,
        S.IntInfinity,
        S.NegativeIntInfinity,
    )
    return expr in all_infinities
# pyrefly: ignore [invalid-inheritance]
class NegativeIntInfinity(Number, metaclass=Singleton):
    """Negative integer infinite quantity.

    NegativeIntInfinity is a singleton, and can be accessed
    by ``S.NegativeIntInfinity``.

    Arithmetic that mixes this value with sympy's float infinities
    (``oo`` / ``-oo``) yields a float infinity, and true division always
    yields a float infinity (or NaN), matching ``IntInfinity``.

    See Also
    ========

    IntInfinity
    """

    # Ensure we get dispatched to before plain numbers
    _op_priority = 100.0

    is_integer = True
    is_extended_real = True
    is_commutative = True
    is_comparable = True
    is_extended_negative = True
    is_number = True
    is_prime = False

    __slots__ = ()

    def __new__(cls):
        return AtomicExpr.__new__(cls)

    def _eval_subs(self, old, new):
        """Substitute ``new`` when ``old`` is this singleton, else return None."""
        if self == old:
            return new

    def _sympystr(self, printer) -> str:
        """Return the text sympy's string printer shows for this value."""
        return "-int_oo"

    # Disabled sketch of evalf overrides, kept for reference (not active code):
    #
    # def _eval_evalf(self, prec=None):
    #     return Float('-inf')
    #
    # def evalf(self, prec=None, **options):
    #     return self._eval_evalf(prec)

    @_sympifyit("other", NotImplemented)
    def __add__(self, other):
        """Add ``other`` to negative integer infinity.

        Either float infinity absorbs the integer one (so the result does
        not depend on operand order), ``-int_oo + int_oo`` and NaN operands
        give NaN, and every other number leaves ``-int_oo`` unchanged.
        """
        if isinstance(other, Number) and global_parameters.evaluate:
            if other is S.Infinity or other is S.NegativeInfinity:
                return other
            if other in (S.IntInfinity, S.NaN):
                return S.NaN
            return self
        return Number.__add__(self, other)

    __radd__ = __add__

    @_sympifyit("other", NotImplemented)
    def __sub__(self, other):
        """Subtract ``other`` from negative integer infinity.

        Subtracting a float infinity gives the float infinity of opposite
        sign, ``-int_oo - (-int_oo)`` and NaN operands give NaN, and every
        other number leaves ``-int_oo`` unchanged.
        """
        if isinstance(other, Number) and global_parameters.evaluate:
            if other is S.Infinity:
                return S.NegativeInfinity
            if other is S.NegativeInfinity:
                return S.Infinity
            if other in (S.NegativeIntInfinity, S.NaN):
                return S.NaN
            return self
        return Number.__sub__(self, other)

    @_sympifyit("other", NotImplemented)
    def __rsub__(self, other):
        """Compute ``other - self`` as ``(-self) + other``."""
        return (-self).__add__(other)

    @_sympifyit("other", NotImplemented)
    def __mul__(self, other):
        """Multiply by ``other``; zero or NaN give NaN, otherwise sign decides."""
        if isinstance(other, Number) and global_parameters.evaluate:
            if other.is_zero or other is S.NaN:
                return S.NaN
            if other.is_extended_positive:
                return self
            return S.IntInfinity
        return Number.__mul__(self, other)

    __rmul__ = __mul__

    @_sympifyit("other", NotImplemented)
    def __truediv__(self, other):
        """True-divide by ``other``; the result is a float infinity or NaN.

        Dividing by any infinity or by NaN gives NaN.  Otherwise the result
        is ``-oo`` for a non-negative divisor and ``oo`` for a negative one.
        """
        if isinstance(other, Number) and global_parameters.evaluate:
            if other in (
                S.Infinity,
                S.IntInfinity,
                S.NegativeInfinity,
                S.NegativeIntInfinity,
                S.NaN,
            ):
                return S.NaN
            if other.is_extended_nonnegative:
                return S.NegativeInfinity  # truediv produces float
            return S.Infinity  # truediv produces float
        return Number.__truediv__(self, other)

    def __abs__(self):
        return S.IntInfinity

    def __neg__(self):
        return S.IntInfinity

    def _eval_power(self, expt):
        """Evaluate ``self ** expt`` for numeric exponents.

        Returns None (leaving the power unevaluated) when ``expt`` is not a
        number.
        """
        if expt.is_number:
            if expt in (
                S.NaN,
                S.Infinity,
                S.NegativeInfinity,
                S.IntInfinity,
                S.NegativeIntInfinity,
            ):
                return S.NaN

            # Positive integer exponents: parity picks the sign.
            if isinstance(expt, sympy.Integer) and expt.is_extended_positive:
                if expt.is_odd:
                    return S.NegativeIntInfinity
                else:
                    return S.IntInfinity

            # Otherwise split (-int_oo)**expt into (-1)**expt * int_oo**expt.
            inf_part = S.IntInfinity**expt
            s_part = S.NegativeOne**expt
            if inf_part == 0 and s_part.is_finite:
                return inf_part
            if (
                inf_part is S.ComplexInfinity
                and s_part.is_finite
                and not s_part.is_zero
            ):
                return S.ComplexInfinity
            return s_part * inf_part

    def _as_mpf_val(self, prec):
        return mlib.fninf

    def __hash__(self):
        return super().__hash__()

    def __eq__(self, other):
        return other is S.NegativeIntInfinity

    def __ne__(self, other):
        return other is not S.NegativeIntInfinity

    def __gt__(self, other):
        if other is S.NegativeInfinity:
            return sympy.true  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.false  # consistency with sympy.oo
        else:
            return sympy.false

    def __ge__(self, other):
        if other is S.NegativeInfinity:
            return sympy.true  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.true  # consistency with sympy.oo
        else:
            return sympy.false

    def __lt__(self, other):
        if other is S.NegativeInfinity:
            return sympy.false  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.false  # consistency with sympy.oo
        else:
            return sympy.true

    def __le__(self, other):
        if other is S.NegativeInfinity:
            return sympy.false  # -sympy.oo < -int_oo
        elif other is S.NegativeIntInfinity:
            return sympy.true  # consistency with sympy.oo
        else:
            return sympy.true

    @_sympifyit("other", NotImplemented)
    def __mod__(self, other):
        """Return NaN for any ``Expr`` operand, ``NotImplemented`` otherwise."""
        if not isinstance(other, Expr):
            return NotImplemented
        return S.NaN

    __rmod__ = __mod__

    def floor(self):
        return self

    def ceiling(self):
        return self

    def as_powers_dict(self):
        """Return the factorisation ``(-1)**1 * int_oo**1`` as a dict."""
        return {S.NegativeOne: 1, S.IntInfinity: 1}
|