pythonmpq.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. """
PythonMPQ: Rational number type based on Python integers.

This class is intended as a pure Python fallback for when gmpy2 is not
installed. If gmpy2 is installed then its mpq type will be used instead. The
mpq type is around 20x faster. We could just use the stdlib Fraction class
here but that is slower:

    from fractions import Fraction
    from sympy.external.pythonmpq import PythonMPQ
    nums = range(1000)
    dens = range(5, 1005)
    rats = [Fraction(n, d) for n, d in zip(nums, dens)]
    sum(rats) # <--- 24 milliseconds
    rats = [PythonMPQ(n, d) for n, d in zip(nums, dens)]
    sum(rats) # <--- 7 milliseconds

Both mpq and Fraction have some awkward features like the behaviour of
division with // and %:

    >>> from fractions import Fraction
    >>> Fraction(2, 3) % Fraction(1, 4)
    1/6

For the QQ domain we do not want this behaviour because there should be no
remainder when dividing rational numbers. SymPy does not make use of this
aspect of mpq when gmpy2 is installed. Since this class is a fallback for that
case we do not bother implementing e.g. __mod__ so that we can be sure we
are not using it when gmpy2 is installed either.
  25. """
  26. from __future__ import annotations
  27. import operator
  28. from math import gcd
  29. from decimal import Decimal
  30. from fractions import Fraction
  31. import sys
  32. from typing import Type
  33. # Used for __hash__
  34. _PyHASH_MODULUS = sys.hash_info.modulus
  35. _PyHASH_INF = sys.hash_info.inf
  36. class PythonMPQ:
  37. """Rational number implementation that is intended to be compatible with
  38. gmpy2's mpq.
  39. Also slightly faster than fractions.Fraction.
  40. PythonMPQ should be treated as immutable although no effort is made to
  41. prevent mutation (since that might slow down calculations).
  42. """
  43. __slots__ = ('numerator', 'denominator')
  44. def __new__(cls, numerator, denominator=None):
  45. """Construct PythonMPQ with gcd computation and checks"""
  46. if denominator is not None:
  47. #
  48. # PythonMPQ(n, d): require n and d to be int and d != 0
  49. #
  50. if isinstance(numerator, int) and isinstance(denominator, int):
  51. # This is the slow part:
  52. divisor = gcd(numerator, denominator)
  53. numerator //= divisor
  54. denominator //= divisor
  55. return cls._new_check(numerator, denominator)
  56. else:
  57. #
  58. # PythonMPQ(q)
  59. #
  60. # Here q can be PythonMPQ, int, Decimal, float, Fraction or str
  61. #
  62. if isinstance(numerator, int):
  63. return cls._new(numerator, 1)
  64. elif isinstance(numerator, PythonMPQ):
  65. return cls._new(numerator.numerator, numerator.denominator)
  66. # Let Fraction handle Decimal/float conversion and str parsing
  67. if isinstance(numerator, (Decimal, float, str)):
  68. numerator = Fraction(numerator)
  69. if isinstance(numerator, Fraction):
  70. return cls._new(numerator.numerator, numerator.denominator)
  71. #
  72. # Reject everything else. This is more strict than mpq which allows
  73. # things like mpq(Fraction, Fraction) or mpq(Decimal, any). The mpq
  74. # behaviour is somewhat inconsistent so we choose to accept only a
  75. # more strict subset of what mpq allows.
  76. #
  77. raise TypeError("PythonMPQ() requires numeric or string argument")
  78. @classmethod
  79. def _new_check(cls, numerator, denominator):
  80. """Construct PythonMPQ, check divide by zero and canonicalize signs"""
  81. if not denominator:
  82. raise ZeroDivisionError(f'Zero divisor {numerator}/{denominator}')
  83. elif denominator < 0:
  84. numerator = -numerator
  85. denominator = -denominator
  86. return cls._new(numerator, denominator)
  87. @classmethod
  88. def _new(cls, numerator, denominator):
  89. """Construct PythonMPQ efficiently (no checks)"""
  90. obj = super().__new__(cls)
  91. obj.numerator = numerator
  92. obj.denominator = denominator
  93. return obj
  94. def __int__(self):
  95. """Convert to int (truncates towards zero)"""
  96. p, q = self.numerator, self.denominator
  97. if p < 0:
  98. return -(-p//q)
  99. return p//q
  100. def __float__(self):
  101. """Convert to float (approximately)"""
  102. return self.numerator / self.denominator
  103. def __bool__(self):
  104. """True/False if nonzero/zero"""
  105. return bool(self.numerator)
  106. def __eq__(self, other):
  107. """Compare equal with PythonMPQ, int, float, Decimal or Fraction"""
  108. if isinstance(other, PythonMPQ):
  109. return (self.numerator == other.numerator
  110. and self.denominator == other.denominator)
  111. elif isinstance(other, self._compatible_types):
  112. return self.__eq__(PythonMPQ(other))
  113. else:
  114. return NotImplemented
  115. def __hash__(self):
  116. """hash - same as mpq/Fraction"""
  117. try:
  118. dinv = pow(self.denominator, -1, _PyHASH_MODULUS)
  119. except ValueError:
  120. hash_ = _PyHASH_INF
  121. else:
  122. hash_ = hash(hash(abs(self.numerator)) * dinv)
  123. result = hash_ if self.numerator >= 0 else -hash_
  124. return -2 if result == -1 else result
  125. def __reduce__(self):
  126. """Deconstruct for pickling"""
  127. return type(self), (self.numerator, self.denominator)
  128. def __str__(self):
  129. """Convert to string"""
  130. if self.denominator != 1:
  131. return f"{self.numerator}/{self.denominator}"
  132. else:
  133. return f"{self.numerator}"
  134. def __repr__(self):
  135. """Convert to string"""
  136. return f"MPQ({self.numerator},{self.denominator})"
  137. def _cmp(self, other, op):
  138. """Helper for lt/le/gt/ge"""
  139. if not isinstance(other, self._compatible_types):
  140. return NotImplemented
  141. lhs = self.numerator * other.denominator
  142. rhs = other.numerator * self.denominator
  143. return op(lhs, rhs)
  144. def __lt__(self, other):
  145. """self < other"""
  146. return self._cmp(other, operator.lt)
  147. def __le__(self, other):
  148. """self <= other"""
  149. return self._cmp(other, operator.le)
  150. def __gt__(self, other):
  151. """self > other"""
  152. return self._cmp(other, operator.gt)
  153. def __ge__(self, other):
  154. """self >= other"""
  155. return self._cmp(other, operator.ge)
  156. def __abs__(self):
  157. """abs(q)"""
  158. return self._new(abs(self.numerator), self.denominator)
  159. def __pos__(self):
  160. """+q"""
  161. return self
  162. def __neg__(self):
  163. """-q"""
  164. return self._new(-self.numerator, self.denominator)
  165. def __add__(self, other):
  166. """q1 + q2"""
  167. if isinstance(other, PythonMPQ):
  168. #
  169. # This is much faster than the naive method used in the stdlib
  170. # fractions module. Not sure where this method comes from
  171. # though...
  172. #
  173. # Compare timings for something like:
  174. # nums = range(1000)
  175. # rats = [PythonMPQ(n, d) for n, d in zip(nums[:-5], nums[5:])]
  176. # sum(rats) # <-- time this
  177. #
  178. ap, aq = self.numerator, self.denominator
  179. bp, bq = other.numerator, other.denominator
  180. g = gcd(aq, bq)
  181. if g == 1:
  182. p = ap*bq + aq*bp
  183. q = bq*aq
  184. else:
  185. q1, q2 = aq//g, bq//g
  186. p, q = ap*q2 + bp*q1, q1*q2
  187. g2 = gcd(p, g)
  188. p, q = (p // g2), q * (g // g2)
  189. elif isinstance(other, int):
  190. p = self.numerator + self.denominator * other
  191. q = self.denominator
  192. else:
  193. return NotImplemented
  194. return self._new(p, q)
  195. def __radd__(self, other):
  196. """z1 + q2"""
  197. if isinstance(other, int):
  198. p = self.numerator + self.denominator * other
  199. q = self.denominator
  200. return self._new(p, q)
  201. else:
  202. return NotImplemented
  203. def __sub__(self ,other):
  204. """q1 - q2"""
  205. if isinstance(other, PythonMPQ):
  206. ap, aq = self.numerator, self.denominator
  207. bp, bq = other.numerator, other.denominator
  208. g = gcd(aq, bq)
  209. if g == 1:
  210. p = ap*bq - aq*bp
  211. q = bq*aq
  212. else:
  213. q1, q2 = aq//g, bq//g
  214. p, q = ap*q2 - bp*q1, q1*q2
  215. g2 = gcd(p, g)
  216. p, q = (p // g2), q * (g // g2)
  217. elif isinstance(other, int):
  218. p = self.numerator - self.denominator*other
  219. q = self.denominator
  220. else:
  221. return NotImplemented
  222. return self._new(p, q)
  223. def __rsub__(self, other):
  224. """z1 - q2"""
  225. if isinstance(other, int):
  226. p = self.denominator * other - self.numerator
  227. q = self.denominator
  228. return self._new(p, q)
  229. else:
  230. return NotImplemented
  231. def __mul__(self, other):
  232. """q1 * q2"""
  233. if isinstance(other, PythonMPQ):
  234. ap, aq = self.numerator, self.denominator
  235. bp, bq = other.numerator, other.denominator
  236. x1 = gcd(ap, bq)
  237. x2 = gcd(bp, aq)
  238. p, q = ((ap//x1)*(bp//x2), (aq//x2)*(bq//x1))
  239. elif isinstance(other, int):
  240. x = gcd(other, self.denominator)
  241. p = self.numerator*(other//x)
  242. q = self.denominator//x
  243. else:
  244. return NotImplemented
  245. return self._new(p, q)
  246. def __rmul__(self, other):
  247. """z1 * q2"""
  248. if isinstance(other, int):
  249. x = gcd(self.denominator, other)
  250. p = self.numerator*(other//x)
  251. q = self.denominator//x
  252. return self._new(p, q)
  253. else:
  254. return NotImplemented
  255. def __pow__(self, exp):
  256. """q ** z"""
  257. p, q = self.numerator, self.denominator
  258. if exp < 0:
  259. p, q, exp = q, p, -exp
  260. return self._new_check(p**exp, q**exp)
  261. def __truediv__(self, other):
  262. """q1 / q2"""
  263. if isinstance(other, PythonMPQ):
  264. ap, aq = self.numerator, self.denominator
  265. bp, bq = other.numerator, other.denominator
  266. x1 = gcd(ap, bp)
  267. x2 = gcd(bq, aq)
  268. p, q = ((ap//x1)*(bq//x2), (aq//x2)*(bp//x1))
  269. elif isinstance(other, int):
  270. x = gcd(other, self.numerator)
  271. p = self.numerator//x
  272. q = self.denominator*(other//x)
  273. else:
  274. return NotImplemented
  275. return self._new_check(p, q)
  276. def __rtruediv__(self, other):
  277. """z / q"""
  278. if isinstance(other, int):
  279. x = gcd(self.numerator, other)
  280. p = self.denominator*(other//x)
  281. q = self.numerator//x
  282. return self._new_check(p, q)
  283. else:
  284. return NotImplemented
  285. _compatible_types: tuple[Type, ...] = ()
  286. #
  287. # These are the types that PythonMPQ will interoperate with for operations
  288. # and comparisons such as ==, + etc. We define this down here so that we can
  289. # include PythonMPQ in the list as well.
  290. #
  291. PythonMPQ._compatible_types = (PythonMPQ, int, Decimal, Fraction)