partitions_.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. from mpmath.libmp import (fzero, from_int, from_rational,
  2. fone, fhalf, bitcount, to_int, mpf_mul, mpf_div, mpf_sub,
  3. mpf_add, mpf_sqrt, mpf_pi, mpf_cosh_sinh, mpf_cos, mpf_sin)
  4. from .residue_ntheory import _sqrt_mod_prime_power, is_quad_residue
  5. from sympy.utilities.decorator import deprecated
  6. from sympy.utilities.memoization import recurrence_memo
  7. import math
  8. from itertools import count
  9. def _pre():
  10. maxn = 10**5
  11. global _factor, _totient
  12. _factor = [0]*maxn
  13. _totient = [1]*maxn
  14. lim = int(maxn**0.5) + 5
  15. for i in range(2, lim):
  16. if _factor[i] == 0:
  17. for j in range(i*i, maxn, i):
  18. if _factor[j] == 0:
  19. _factor[j] = i
  20. for i in range(2, maxn):
  21. if _factor[i] == 0:
  22. _factor[i] = i
  23. _totient[i] = i-1
  24. continue
  25. x = _factor[i]
  26. y = i//x
  27. if y % x == 0:
  28. _totient[i] = _totient[y]*x
  29. else:
  30. _totient[i] = _totient[y]*(x - 1)
def _a(n, k, prec):
    """ Compute the inner sum in HRR formula [1]_

    Returns the inner sum ``A_k(n)`` of the Hardy-Ramanujan-Rademacher
    series as a low-level ``mpmath.libmp`` mpf value computed with
    ``prec`` bits.

    When ``k`` is a prime power ``p**e``, closed forms are used (separate
    cases for ``p == 2``, ``p == 3`` and ``p > 3``).  Otherwise the full
    power ``k2 = p**e`` of the smallest prime ``p`` dividing ``k`` is split
    off.  The result is then the product of two recursive calls,
    ``_a(n1, k1)`` and ``_a(n2, k2)`` with ``k1 = k // k2``, negated in the
    ``p == 2, e == 2`` case.

    The module-level tables ``_factor`` (smallest prime factor) and
    ``_totient`` (Euler phi) are read, so ``_pre()`` must have been called
    and ``k`` must be smaller than the tables' length.

    References
    ==========

    .. [1] https://msp.org/pjm/1956/6-1/pjm-v6-n1-p18-p.pdf
    """
    if k == 1:
        return fone
    # Split k = k1*k2, where p is the smallest prime factor of k and
    # k2 = p**e is the full power of p dividing k (so gcd(k1, p) == 1).
    k1 = k
    e = 0
    p = _factor[k]
    while k1 % p == 0:
        k1 //= p
        e += 1
    k2 = k//k1 # k2 = p^e
    # The closed forms below work with residues of 1 - 24n.
    v = 1 - 24*n
    pi = mpf_pi(prec)
    if k1 == 1:
        # k = p^e: evaluate A_k(n) in closed form.
        if p == 2:
            mod = 8*k
            v = mod + v % mod
            # 9 has multiplicative order k modulo 8k = 2**(e + 3), so
            # 9**(k - 1) is the inverse of 9 modulo 8k.
            v = (v*pow(9, k - 1, mod)) % mod
            m = _sqrt_mod_prime_power(v, 2, e + 3)[0]
            arg = mpf_div(mpf_mul(
                from_int(4*m), pi, prec), from_int(mod), prec)
            # v is odd, so m is odd and 2 - (m % 4) is the Jacobi symbol (-1/m).
            return mpf_mul(mpf_mul(
                from_int((-1)**e*(2 - (m % 4))),
                mpf_sqrt(from_int(k), prec), prec),
                mpf_sin(arg, prec), prec)
        if p == 3:
            mod = 3*k
            v = mod + v % mod
            if e > 1:
                # 64 has order k//3 modulo 3k = 3**(e + 1), so this multiplies
                # by the inverse of 64 (for e == 1, 64 == 1 mod 9 already).
                v = (v*pow(64, k//3 - 1, mod)) % mod
            m = _sqrt_mod_prime_power(v, 3, e + 1)[0]
            arg = mpf_div(mpf_mul(from_int(4*m), pi, prec),
                from_int(mod), prec)
            # v == 1 (mod 3), so 3 does not divide m and 3 - 2*(m % 3) is the
            # Legendre symbol (m/3).
            return mpf_mul(mpf_mul(
                from_int(2*(-1)**(e + 1)*(3 - 2*(m % 3))),
                mpf_sqrt(from_int(k//3), prec), prec),
                mpf_sin(arg, prec), prec)
        # From here on p > 3.
        v = k + v % k
        # Jacobi symbol (3/k) for k coprime to 6: -1 exactly when k == +-5 (mod 12).
        jacobi3 = -1 if k % 12 in [5, 7] else 1
        if v % p == 0:
            if e == 1:
                return mpf_mul(
                    from_int(jacobi3),
                    mpf_sqrt(from_int(k), prec), prec)
            return fzero
        if not is_quad_residue(v, p):
            return fzero
        # By Euler's theorem 576**(phi(k) - 1) is the inverse of 576 = 24**2
        # modulo k.  The product is not reduced mod k here (presumably
        # _sqrt_mod_prime_power reduces it -- TODO confirm).
        _phi = p**(e - 1)*(p - 1)
        v = (v*pow(576, _phi - 1, k))
        m = _sqrt_mod_prime_power(v, p, e)[0]
        arg = mpf_div(
            mpf_mul(from_int(4*m), pi, prec),
            from_int(k), prec)
        return mpf_mul(mpf_mul(
            from_int(2*jacobi3),
            mpf_sqrt(from_int(k), prec), prec),
            mpf_cos(arg, prec), prec)
    # k is not a prime power: combine A_k1(n1) and A_k2(n2) for suitable n1, n2.
    # pow(x, _totient[m] - 1, m) is the inverse of x modulo m when gcd(x, m) == 1.
    if p != 2 or e >= 3:
        d1, d2 = math.gcd(k1, 24), math.gcd(k2, 24)
        # NOTE: e is reused here and no longer holds the exponent of p.
        e = 24//(d1*d2)
        n1 = ((d2*e*n + (k2**2 - 1)//d1)*
            pow(e*k2*k2*d2, _totient[k1] - 1, k1)) % k1
        n2 = ((d1*e*n + (k1**2 - 1)//d2)*
            pow(e*k1*k1*d1, _totient[k2] - 1, k2)) % k2
        return mpf_mul(_a(n1, k1, prec), _a(n2, k2, prec), prec)
    # Remaining cases: p == 2 with k2 == 4 (e == 2) or k2 == 2 (e == 1).
    if e == 2:
        n1 = ((8*n + 5)*pow(128, _totient[k1] - 1, k1)) % k1
        n2 = (4 + ((n - 2 - (k1**2 - 1)//8)*(k1**2)) % 4) % 4
        # NOTE(review): unlike the other branches, the outer mpf_mul gets no
        # prec argument -- confirm this is intended.
        return mpf_mul(mpf_mul(
            from_int(-1),
            _a(n1, k1, prec), prec),
            _a(n2, k2, prec))
    n1 = ((8*n + 1)*pow(32, _totient[k1] - 1, k1)) % k1
    n2 = (2 + (n - (k1**2 - 1)//8) % 2) % 2
    return mpf_mul(_a(n1, k1, prec), _a(n2, k2, prec), prec)
  111. def _d(n, j, prec, sq23pi, sqrt8):
  112. """
  113. Compute the sinh term in the outer sum of the HRR formula.
  114. The constants sqrt(2/3*pi) and sqrt(8) must be precomputed.
  115. """
  116. j = from_int(j)
  117. pi = mpf_pi(prec)
  118. a = mpf_div(sq23pi, j, prec)
  119. b = mpf_sub(from_int(n), from_rational(1, 24, prec), prec)
  120. c = mpf_sqrt(b, prec)
  121. ch, sh = mpf_cosh_sinh(mpf_mul(a, c), prec)
  122. D = mpf_div(
  123. mpf_sqrt(j, prec),
  124. mpf_mul(mpf_mul(sqrt8, b), pi), prec)
  125. E = mpf_sub(mpf_mul(a, ch), mpf_div(sh, c, prec), prec)
  126. return mpf_mul(D, E)
  127. @recurrence_memo([1, 1])
  128. def _partition_rec(n: int, prev) -> int:
  129. """ Calculate the partition function P(n)
  130. Parameters
  131. ==========
  132. n : int
  133. nonnegative integer
  134. """
  135. v = 0
  136. penta = 0 # pentagonal number: 1, 5, 12, ...
  137. for i in count():
  138. penta += 3*i + 1
  139. np = n - penta
  140. if np < 0:
  141. break
  142. s = prev[np]
  143. np -= i + 1
  144. # np = n - gp where gp = generalized pentagonal: 2, 7, 15, ...
  145. if 0 <= np:
  146. s += prev[np]
  147. v += -s if i % 2 else s
  148. return v
  149. def _partition(n: int) -> int:
  150. """ Calculate the partition function P(n)
  151. Parameters
  152. ==========
  153. n : int
  154. """
  155. if n < 0:
  156. return 0
  157. if (n <= 200_000 and n - _partition_rec.cache_length() < 70 or
  158. _partition_rec.cache_length() == 2 and n < 14_400):
  159. # There will be 2*10**5 elements created here
  160. # and n elements created by partition, so in case we
  161. # are going to be working with small n, we just
  162. # use partition to calculate (and cache) the values
  163. # since lookup is used there while summation, using
  164. # _factor and _totient, will be used below. But we
  165. # only do so if n is relatively close to the length
  166. # of the cache since doing 1 calculation here is about
  167. # the same as adding 70 elements to the cache. In addition,
  168. # the startup here costs about the same as calculating the first
  169. # 14,400 values via partition, so we delay startup here unless n
  170. # is smaller than that.
  171. return _partition_rec(n)
  172. if '_factor' not in globals():
  173. _pre()
  174. # Estimate number of bits in p(n). This formula could be tidied
  175. pbits = int((
  176. math.pi*(2*n/3.)**0.5 -
  177. math.log(4*n))/math.log(10) + 1) * \
  178. math.log2(10)
  179. prec = p = int(pbits*1.1 + 100)
  180. # find the number of terms needed so rounded sum will be accurate
  181. # using Rademacher's bound M(n, N) for the remainder after a partial
  182. # sum of N terms (https://arxiv.org/pdf/1205.5991.pdf, (1.8))
  183. c1 = 44*math.pi**2/(225*math.sqrt(3))
  184. c2 = math.pi*math.sqrt(2)/75
  185. c3 = math.pi*math.sqrt(2/3)
  186. def _M(n, N):
  187. sqrt = math.sqrt
  188. return c1/sqrt(N) + c2*sqrt(N/(n - 1))*math.sinh(c3*sqrt(n)/N)
  189. big = max(9, math.ceil(n**0.5)) # should be too large (for n > 65, ceil should work)
  190. assert _M(n, big) < 0.5 # else double big until too large
  191. while big > 40 and _M(n, big) < 0.5:
  192. big //= 2
  193. small = big
  194. big = small*2
  195. while big - small > 1:
  196. N = (big + small)//2
  197. if (er := _M(n, N)) < 0.5:
  198. big = N
  199. elif er >= 0.5:
  200. small = N
  201. M = big # done with function M; now have value
  202. # sanity check for expected size of answer
  203. if M > 10**5: # i.e. M > maxn
  204. raise ValueError("Input too big") # i.e. n > 149832547102
  205. # calculate it
  206. s = fzero
  207. sq23pi = mpf_mul(mpf_sqrt(from_rational(2, 3, p), p), mpf_pi(p), p)
  208. sqrt8 = mpf_sqrt(from_int(8), p)
  209. for q in range(1, M):
  210. a = _a(n, q, p)
  211. d = _d(n, q, p, sq23pi, sqrt8)
  212. s = mpf_add(s, mpf_mul(a, d), prec)
  213. # On average, the terms decrease rapidly in magnitude.
  214. # Dynamically reducing the precision greatly improves
  215. # performance.
  216. p = bitcount(abs(to_int(d))) + 50
  217. return int(to_int(mpf_add(s, fhalf, prec)))
  218. @deprecated("""\
  219. The `sympy.ntheory.partitions_.npartitions` has been moved to `sympy.functions.combinatorial.numbers.partition`.""",
  220. deprecated_since_version="1.13",
  221. active_deprecations_target='deprecated-ntheory-symbolic-functions')
  222. def npartitions(n, verbose=False):
  223. """
  224. Calculate the partition function P(n), i.e. the number of ways that
  225. n can be written as a sum of positive integers.
  226. .. deprecated:: 1.13
  227. The ``npartitions`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.partition`
  228. instead. See its documentation for more information. See
  229. :ref:`deprecated-ntheory-symbolic-functions` for details.
  230. P(n) is computed using the Hardy-Ramanujan-Rademacher formula [1]_.
  231. The correctness of this implementation has been tested through $10^{10}$.
  232. Examples
  233. ========
  234. >>> from sympy.functions.combinatorial.numbers import partition
  235. >>> partition(25)
  236. 1958
  237. References
  238. ==========
  239. .. [1] https://mathworld.wolfram.com/PartitionFunctionP.html
  240. """
  241. from sympy.functions.combinatorial.numbers import partition as func_partition
  242. return func_partition(n)