numbers.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. # mypy: allow-untyped-defs
  2. import mpmath.libmp as mlib # type: ignore[import-untyped]
  3. import sympy
  4. from sympy import Expr
  5. from sympy.core.decorators import _sympifyit
  6. from sympy.core.expr import AtomicExpr
  7. from sympy.core.numbers import Number
  8. from sympy.core.parameters import global_parameters
  9. from sympy.core.singleton import S, Singleton
  10. # pyrefly: ignore [invalid-inheritance]
  11. class IntInfinity(Number, metaclass=Singleton):
  12. r"""Positive integer infinite quantity.
  13. Integer infinity is a value in an extended integers which
  14. is greater than all other integers. We distinguish it from
  15. sympy's existing notion of infinity in that it reports that
  16. it is_integer.
  17. Infinity is a singleton, and can be accessed by ``S.IntInfinity``,
  18. or can be imported as ``int_oo``.
  19. """
  20. # NB: We can't actually mark this as infinite, as integer and infinite are
  21. # inconsistent assumptions in sympy. We also report that we are complex,
  22. # different from sympy.oo
  23. is_integer = True
  24. is_commutative = True
  25. is_number = True
  26. is_extended_real = True
  27. is_comparable = True
  28. is_extended_positive = True
  29. is_prime = False
  30. # Ensure we get dispatched to before plain numbers
  31. _op_priority = 100.0
  32. __slots__ = ()
  33. def __new__(cls):
  34. return AtomicExpr.__new__(cls)
  35. def _sympystr(self, printer) -> str:
  36. return "int_oo"
  37. def _eval_subs(self, old, new):
  38. if self == old:
  39. return new
  40. # We could do these, not sure about it
  41. """
  42. def _eval_evalf(self, prec=None):
  43. return Float('inf')
  44. def evalf(self, prec=None, **options):
  45. return self._eval_evalf(prec)
  46. """
  47. @_sympifyit("other", NotImplemented)
  48. def __add__(self, other):
  49. if isinstance(other, Number) and global_parameters.evaluate:
  50. if other in (S.Infinity, S.NegativeInfinity):
  51. return other
  52. if other in (S.NegativeIntInfinity, S.NaN):
  53. return S.NaN
  54. return self
  55. return Number.__add__(self, other)
  56. __radd__ = __add__
  57. @_sympifyit("other", NotImplemented)
  58. def __sub__(self, other):
  59. if isinstance(other, Number) and global_parameters.evaluate:
  60. if other is S.Infinity:
  61. return S.NegativeInfinity
  62. if other is S.NegativeInfinity:
  63. return S.Infinity
  64. if other in (S.IntInfinity, S.NaN):
  65. return S.NaN
  66. return self
  67. return Number.__sub__(self, other)
  68. @_sympifyit("other", NotImplemented)
  69. def __rsub__(self, other):
  70. return (-self).__add__(other)
  71. @_sympifyit("other", NotImplemented)
  72. def __mul__(self, other):
  73. if isinstance(other, Number) and global_parameters.evaluate:
  74. if other.is_zero or other is S.NaN:
  75. return S.NaN
  76. if other.is_extended_positive:
  77. return self
  78. return S.NegativeIntInfinity
  79. return Number.__mul__(self, other)
  80. __rmul__ = __mul__
  81. @_sympifyit("other", NotImplemented)
  82. def __truediv__(self, other):
  83. if isinstance(other, Number) and global_parameters.evaluate:
  84. if other in (
  85. S.Infinity,
  86. S.IntInfinity,
  87. S.NegativeInfinity,
  88. S.NegativeIntInfinity,
  89. S.NaN,
  90. ):
  91. return S.NaN
  92. if other.is_extended_nonnegative:
  93. return S.Infinity # truediv produces float
  94. return S.NegativeInfinity # truediv produces float
  95. return Number.__truediv__(self, other)
  96. def __abs__(self):
  97. return S.IntInfinity
  98. def __neg__(self):
  99. return S.NegativeIntInfinity
  100. def _eval_power(self, expt):
  101. if expt.is_extended_positive:
  102. return S.IntInfinity
  103. if expt.is_extended_negative:
  104. return S.Zero
  105. if expt is S.NaN:
  106. return S.NaN
  107. if expt is S.ComplexInfinity:
  108. return S.NaN
  109. if expt.is_extended_real is False and expt.is_number:
  110. from sympy.functions.elementary.complexes import re
  111. expt_real = re(expt)
  112. if expt_real.is_positive:
  113. return S.ComplexInfinity
  114. if expt_real.is_negative:
  115. return S.Zero
  116. if expt_real.is_zero:
  117. return S.NaN
  118. return self ** expt.evalf()
  119. def _as_mpf_val(self, prec):
  120. return mlib.finf
  121. def __hash__(self):
  122. return super().__hash__()
  123. def __eq__(self, other):
  124. return other is S.IntInfinity
  125. def __ne__(self, other):
  126. return other is not S.IntInfinity
  127. def __gt__(self, other):
  128. if other is S.Infinity:
  129. return sympy.false # sympy.oo > int_oo
  130. elif other is S.IntInfinity:
  131. return sympy.false # consistency with sympy.oo
  132. else:
  133. return sympy.true
  134. def __ge__(self, other):
  135. if other is S.Infinity:
  136. return sympy.false # sympy.oo > int_oo
  137. elif other is S.IntInfinity:
  138. return sympy.true # consistency with sympy.oo
  139. else:
  140. return sympy.true
  141. def __lt__(self, other):
  142. if other is S.Infinity:
  143. return sympy.true # sympy.oo > int_oo
  144. elif other is S.IntInfinity:
  145. return sympy.false # consistency with sympy.oo
  146. else:
  147. return sympy.false
  148. def __le__(self, other):
  149. if other is S.Infinity:
  150. return sympy.true # sympy.oo > int_oo
  151. elif other is S.IntInfinity:
  152. return sympy.true # consistency with sympy.oo
  153. else:
  154. return sympy.false
  155. @_sympifyit("other", NotImplemented)
  156. def __mod__(self, other):
  157. if not isinstance(other, Expr):
  158. return NotImplemented
  159. return S.NaN
  160. __rmod__ = __mod__
  161. def floor(self):
  162. return self
  163. def ceiling(self):
  164. return self
# Module-level alias for the positive integer-infinity singleton.
int_oo = S.IntInfinity
  166. def is_infinite(expr) -> bool:
  167. """Check if an expression is any type of infinity (positive or negative).
  168. This handles both sympy's built-in infinities (oo, -oo) and PyTorch's
  169. integer infinities (int_oo, -int_oo).
  170. Note: We cannot rely on sympy's is_finite property because IntInfinity
  171. and NegativeIntInfinity have is_integer=True, which implies is_finite=True
  172. in sympy's assumption system.
  173. """
  174. return expr in (
  175. S.Infinity,
  176. S.NegativeInfinity,
  177. S.IntInfinity,
  178. S.NegativeIntInfinity,
  179. )
  180. # pyrefly: ignore [invalid-inheritance]
  181. class NegativeIntInfinity(Number, metaclass=Singleton):
  182. """Negative integer infinite quantity.
  183. NegativeInfinity is a singleton, and can be accessed
  184. by ``S.NegativeInfinity``.
  185. See Also
  186. ========
  187. IntInfinity
  188. """
  189. # Ensure we get dispatched to before plain numbers
  190. _op_priority = 100.0
  191. is_integer = True
  192. is_extended_real = True
  193. is_commutative = True
  194. is_comparable = True
  195. is_extended_negative = True
  196. is_number = True
  197. is_prime = False
  198. __slots__ = ()
  199. def __new__(cls):
  200. return AtomicExpr.__new__(cls)
  201. def _eval_subs(self, old, new):
  202. if self == old:
  203. return new
  204. def _sympystr(self, printer) -> str:
  205. return "-int_oo"
  206. """
  207. def _eval_evalf(self, prec=None):
  208. return Float('-inf')
  209. def evalf(self, prec=None, **options):
  210. return self._eval_evalf(prec)
  211. """
  212. @_sympifyit("other", NotImplemented)
  213. def __add__(self, other):
  214. if isinstance(other, Number) and global_parameters.evaluate:
  215. if other is S.Infinity:
  216. return S.Infinity
  217. if other in (S.IntInfinity, S.NaN):
  218. return S.NaN
  219. return self
  220. return Number.__add__(self, other)
  221. __radd__ = __add__
  222. @_sympifyit("other", NotImplemented)
  223. def __sub__(self, other):
  224. if isinstance(other, Number) and global_parameters.evaluate:
  225. if other is S.NegativeInfinity:
  226. return S.Infinity
  227. if other in (S.NegativeIntInfinity, S.NaN):
  228. return S.NaN
  229. return self
  230. return Number.__sub__(self, other)
  231. @_sympifyit("other", NotImplemented)
  232. def __rsub__(self, other):
  233. return (-self).__add__(other)
  234. @_sympifyit("other", NotImplemented)
  235. def __mul__(self, other):
  236. if isinstance(other, Number) and global_parameters.evaluate:
  237. if other.is_zero or other is S.NaN:
  238. return S.NaN
  239. if other.is_extended_positive:
  240. return self
  241. return S.IntInfinity
  242. return Number.__mul__(self, other)
  243. __rmul__ = __mul__
  244. @_sympifyit("other", NotImplemented)
  245. def __truediv__(self, other):
  246. if isinstance(other, Number) and global_parameters.evaluate:
  247. if other in (
  248. S.Infinity,
  249. S.IntInfinity,
  250. S.NegativeInfinity,
  251. S.NegativeIntInfinity,
  252. S.NaN,
  253. ):
  254. return S.NaN
  255. if other.is_extended_nonnegative:
  256. return self
  257. return S.Infinity # truediv returns float
  258. return Number.__truediv__(self, other)
  259. def __abs__(self):
  260. return S.IntInfinity
  261. def __neg__(self):
  262. return S.IntInfinity
  263. def _eval_power(self, expt):
  264. if expt.is_number:
  265. if expt in (
  266. S.NaN,
  267. S.Infinity,
  268. S.NegativeInfinity,
  269. S.IntInfinity,
  270. S.NegativeIntInfinity,
  271. ):
  272. return S.NaN
  273. if isinstance(expt, sympy.Integer) and expt.is_extended_positive:
  274. if expt.is_odd:
  275. return S.NegativeIntInfinity
  276. else:
  277. return S.IntInfinity
  278. inf_part = S.IntInfinity**expt
  279. s_part = S.NegativeOne**expt
  280. if inf_part == 0 and s_part.is_finite:
  281. return inf_part
  282. if (
  283. inf_part is S.ComplexInfinity
  284. and s_part.is_finite
  285. and not s_part.is_zero
  286. ):
  287. return S.ComplexInfinity
  288. return s_part * inf_part
  289. def _as_mpf_val(self, prec):
  290. return mlib.fninf
  291. def __hash__(self):
  292. return super().__hash__()
  293. def __eq__(self, other):
  294. return other is S.NegativeIntInfinity
  295. def __ne__(self, other):
  296. return other is not S.NegativeIntInfinity
  297. def __gt__(self, other):
  298. if other is S.NegativeInfinity:
  299. return sympy.true # -sympy.oo < -int_oo
  300. elif other is S.NegativeIntInfinity:
  301. return sympy.false # consistency with sympy.oo
  302. else:
  303. return sympy.false
  304. def __ge__(self, other):
  305. if other is S.NegativeInfinity:
  306. return sympy.true # -sympy.oo < -int_oo
  307. elif other is S.NegativeIntInfinity:
  308. return sympy.true # consistency with sympy.oo
  309. else:
  310. return sympy.false
  311. def __lt__(self, other):
  312. if other is S.NegativeInfinity:
  313. return sympy.false # -sympy.oo < -int_oo
  314. elif other is S.NegativeIntInfinity:
  315. return sympy.false # consistency with sympy.oo
  316. else:
  317. return sympy.true
  318. def __le__(self, other):
  319. if other is S.NegativeInfinity:
  320. return sympy.false # -sympy.oo < -int_oo
  321. elif other is S.NegativeIntInfinity:
  322. return sympy.true # consistency with sympy.oo
  323. else:
  324. return sympy.true
  325. @_sympifyit("other", NotImplemented)
  326. def __mod__(self, other):
  327. if not isinstance(other, Expr):
  328. return NotImplemented
  329. return S.NaN
  330. __rmod__ = __mod__
  331. def floor(self):
  332. return self
  333. def ceiling(self):
  334. return self
  335. def as_powers_dict(self):
  336. return {S.NegativeOne: 1, S.IntInfinity: 1}