convolutions.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597
"""
Convolution (using **FFT**, **NTT**, **FWHT**), Subset Convolution,
Covering Product, Intersecting Product
"""
  5. from sympy.core import S, sympify, Rational
  6. from sympy.core.function import expand_mul
  7. from sympy.discrete.transforms import (
  8. fft, ifft, ntt, intt, fwht, ifwht,
  9. mobius_transform, inverse_mobius_transform)
  10. from sympy.external.gmpy import MPZ, lcm
  11. from sympy.utilities.iterables import iterable
  12. from sympy.utilities.misc import as_int
  13. def convolution(a, b, cycle=0, dps=None, prime=None, dyadic=None, subset=None):
  14. """
  15. Performs convolution by determining the type of desired
  16. convolution using hints.
  17. Exactly one of ``dps``, ``prime``, ``dyadic``, ``subset`` arguments
  18. should be specified explicitly for identifying the type of convolution,
  19. and the argument ``cycle`` can be specified optionally.
  20. For the default arguments, linear convolution is performed using **FFT**.
  21. Parameters
  22. ==========
  23. a, b : iterables
  24. The sequences for which convolution is performed.
  25. cycle : Integer
  26. Specifies the length for doing cyclic convolution.
  27. dps : Integer
  28. Specifies the number of decimal digits for precision for
  29. performing **FFT** on the sequence.
  30. prime : Integer
  31. Prime modulus of the form `(m 2^k + 1)` to be used for
  32. performing **NTT** on the sequence.
  33. dyadic : bool
  34. Identifies the convolution type as dyadic (*bitwise-XOR*)
  35. convolution, which is performed using **FWHT**.
  36. subset : bool
  37. Identifies the convolution type as subset convolution.
  38. Examples
  39. ========
  40. >>> from sympy import convolution, symbols, S, I
  41. >>> u, v, w, x, y, z = symbols('u v w x y z')
  42. >>> convolution([1 + 2*I, 4 + 3*I], [S(5)/4, 6], dps=3)
  43. [1.25 + 2.5*I, 11.0 + 15.8*I, 24.0 + 18.0*I]
  44. >>> convolution([1, 2, 3], [4, 5, 6], cycle=3)
  45. [31, 31, 28]
  46. >>> convolution([111, 777], [888, 444], prime=19*2**10 + 1)
  47. [1283, 19351, 14219]
  48. >>> convolution([111, 777], [888, 444], prime=19*2**10 + 1, cycle=2)
  49. [15502, 19351]
  50. >>> convolution([u, v], [x, y, z], dyadic=True)
  51. [u*x + v*y, u*y + v*x, u*z, v*z]
  52. >>> convolution([u, v], [x, y, z], dyadic=True, cycle=2)
  53. [u*x + u*z + v*y, u*y + v*x + v*z]
  54. >>> convolution([u, v, w], [x, y, z], subset=True)
  55. [u*x, u*y + v*x, u*z + w*x, v*z + w*y]
  56. >>> convolution([u, v, w], [x, y, z], subset=True, cycle=3)
  57. [u*x + v*z + w*y, u*y + v*x, u*z + w*x]
  58. """
  59. c = as_int(cycle)
  60. if c < 0:
  61. raise ValueError("The length for cyclic convolution "
  62. "must be non-negative")
  63. dyadic = True if dyadic else None
  64. subset = True if subset else None
  65. if sum(x is not None for x in (prime, dps, dyadic, subset)) > 1:
  66. raise TypeError("Ambiguity in determining the type of convolution")
  67. if prime is not None:
  68. ls = convolution_ntt(a, b, prime=prime)
  69. return ls if not c else [sum(ls[i::c]) % prime for i in range(c)]
  70. if dyadic:
  71. ls = convolution_fwht(a, b)
  72. elif subset:
  73. ls = convolution_subset(a, b)
  74. else:
  75. def loop(a):
  76. dens = []
  77. for i in a:
  78. if isinstance(i, Rational) and i.q - 1:
  79. dens.append(i.q)
  80. elif not isinstance(i, int):
  81. return
  82. if dens:
  83. l = lcm(*dens)
  84. return [i*l if type(i) is int else i.p*(l//i.q) for i in a], l
  85. # no lcm of den to deal with
  86. return a, 1
  87. ls = None
  88. da = loop(a)
  89. if da is not None:
  90. db = loop(b)
  91. if db is not None:
  92. (ia, ma), (ib, mb) = da, db
  93. den = ma*mb
  94. ls = convolution_int(ia, ib)
  95. if den != 1:
  96. ls = [Rational(i, den) for i in ls]
  97. if ls is None:
  98. ls = convolution_fft(a, b, dps)
  99. return ls if not c else [sum(ls[i::c]) for i in range(c)]
#----------------------------------------------------------------------------#
#                                                                            #
#                       Convolution for Complex domain                       #
#                                                                            #
#----------------------------------------------------------------------------#
  105. def convolution_fft(a, b, dps=None):
  106. """
  107. Performs linear convolution using Fast Fourier Transform.
  108. Parameters
  109. ==========
  110. a, b : iterables
  111. The sequences for which convolution is performed.
  112. dps : Integer
  113. Specifies the number of decimal digits for precision.
  114. Examples
  115. ========
  116. >>> from sympy import S, I
  117. >>> from sympy.discrete.convolutions import convolution_fft
  118. >>> convolution_fft([2, 3], [4, 5])
  119. [8, 22, 15]
  120. >>> convolution_fft([2, 5], [6, 7, 3])
  121. [12, 44, 41, 15]
  122. >>> convolution_fft([1 + 2*I, 4 + 3*I], [S(5)/4, 6])
  123. [5/4 + 5*I/2, 11 + 63*I/4, 24 + 18*I]
  124. References
  125. ==========
  126. .. [1] https://en.wikipedia.org/wiki/Convolution_theorem
  127. .. [2] https://en.wikipedia.org/wiki/Discrete_Fourier_transform_(general%29
  128. """
  129. a, b = a[:], b[:]
  130. n = m = len(a) + len(b) - 1 # convolution size
  131. if n > 0 and n&(n - 1): # not a power of 2
  132. n = 2**n.bit_length()
  133. # padding with zeros
  134. a += [S.Zero]*(n - len(a))
  135. b += [S.Zero]*(n - len(b))
  136. a, b = fft(a, dps), fft(b, dps)
  137. a = [expand_mul(x*y) for x, y in zip(a, b)]
  138. a = ifft(a, dps)[:m]
  139. return a
#----------------------------------------------------------------------------#
#                                                                            #
#                           Convolution for GF(p)                            #
#                                                                            #
#----------------------------------------------------------------------------#
  145. def convolution_ntt(a, b, prime):
  146. """
  147. Performs linear convolution using Number Theoretic Transform.
  148. Parameters
  149. ==========
  150. a, b : iterables
  151. The sequences for which convolution is performed.
  152. prime : Integer
  153. Prime modulus of the form `(m 2^k + 1)` to be used for performing
  154. **NTT** on the sequence.
  155. Examples
  156. ========
  157. >>> from sympy.discrete.convolutions import convolution_ntt
  158. >>> convolution_ntt([2, 3], [4, 5], prime=19*2**10 + 1)
  159. [8, 22, 15]
  160. >>> convolution_ntt([2, 5], [6, 7, 3], prime=19*2**10 + 1)
  161. [12, 44, 41, 15]
  162. >>> convolution_ntt([333, 555], [222, 666], prime=19*2**10 + 1)
  163. [15555, 14219, 19404]
  164. References
  165. ==========
  166. .. [1] https://en.wikipedia.org/wiki/Convolution_theorem
  167. .. [2] https://en.wikipedia.org/wiki/Discrete_Fourier_transform_(general%29
  168. """
  169. a, b, p = a[:], b[:], as_int(prime)
  170. n = m = len(a) + len(b) - 1 # convolution size
  171. if n > 0 and n&(n - 1): # not a power of 2
  172. n = 2**n.bit_length()
  173. # padding with zeros
  174. a += [0]*(n - len(a))
  175. b += [0]*(n - len(b))
  176. a, b = ntt(a, p), ntt(b, p)
  177. a = [x*y % p for x, y in zip(a, b)]
  178. a = intt(a, p)[:m]
  179. return a
#----------------------------------------------------------------------------#
#                                                                            #
#                         Convolution for 2**n-group                         #
#                                                                            #
#----------------------------------------------------------------------------#
  185. def convolution_fwht(a, b):
  186. """
  187. Performs dyadic (*bitwise-XOR*) convolution using Fast Walsh Hadamard
  188. Transform.
  189. The convolution is automatically padded to the right with zeros, as the
  190. *radix-2 FWHT* requires the number of sample points to be a power of 2.
  191. Parameters
  192. ==========
  193. a, b : iterables
  194. The sequences for which convolution is performed.
  195. Examples
  196. ========
  197. >>> from sympy import symbols, S, I
  198. >>> from sympy.discrete.convolutions import convolution_fwht
  199. >>> u, v, x, y = symbols('u v x y')
  200. >>> convolution_fwht([u, v], [x, y])
  201. [u*x + v*y, u*y + v*x]
  202. >>> convolution_fwht([2, 3], [4, 5])
  203. [23, 22]
  204. >>> convolution_fwht([2, 5 + 4*I, 7], [6*I, 7, 3 + 4*I])
  205. [56 + 68*I, -10 + 30*I, 6 + 50*I, 48 + 32*I]
  206. >>> convolution_fwht([S(33)/7, S(55)/6, S(7)/4], [S(2)/3, 5])
  207. [2057/42, 1870/63, 7/6, 35/4]
  208. References
  209. ==========
  210. .. [1] https://www.radioeng.cz/fulltexts/2002/02_03_40_42.pdf
  211. .. [2] https://en.wikipedia.org/wiki/Hadamard_transform
  212. """
  213. if not a or not b:
  214. return []
  215. a, b = a[:], b[:]
  216. n = max(len(a), len(b))
  217. if n&(n - 1): # not a power of 2
  218. n = 2**n.bit_length()
  219. # padding with zeros
  220. a += [S.Zero]*(n - len(a))
  221. b += [S.Zero]*(n - len(b))
  222. a, b = fwht(a), fwht(b)
  223. a = [expand_mul(x*y) for x, y in zip(a, b)]
  224. a = ifwht(a)
  225. return a
#----------------------------------------------------------------------------#
#                                                                            #
#                             Subset Convolution                             #
#                                                                            #
#----------------------------------------------------------------------------#
  231. def convolution_subset(a, b):
  232. """
  233. Performs Subset Convolution of given sequences.
  234. The indices of each argument, considered as bit strings, correspond to
  235. subsets of a finite set.
  236. The sequence is automatically padded to the right with zeros, as the
  237. definition of subset based on bitmasks (indices) requires the size of
  238. sequence to be a power of 2.
  239. Parameters
  240. ==========
  241. a, b : iterables
  242. The sequences for which convolution is performed.
  243. Examples
  244. ========
  245. >>> from sympy import symbols, S
  246. >>> from sympy.discrete.convolutions import convolution_subset
  247. >>> u, v, x, y, z = symbols('u v x y z')
  248. >>> convolution_subset([u, v], [x, y])
  249. [u*x, u*y + v*x]
  250. >>> convolution_subset([u, v, x], [y, z])
  251. [u*y, u*z + v*y, x*y, x*z]
  252. >>> convolution_subset([1, S(2)/3], [3, 4])
  253. [3, 6]
  254. >>> convolution_subset([1, 3, S(5)/7], [7])
  255. [7, 21, 5, 0]
  256. References
  257. ==========
  258. .. [1] https://people.csail.mit.edu/rrw/presentations/subset-conv.pdf
  259. """
  260. if not a or not b:
  261. return []
  262. if not iterable(a) or not iterable(b):
  263. raise TypeError("Expected a sequence of coefficients for convolution")
  264. a = [sympify(arg) for arg in a]
  265. b = [sympify(arg) for arg in b]
  266. n = max(len(a), len(b))
  267. if n&(n - 1): # not a power of 2
  268. n = 2**n.bit_length()
  269. # padding with zeros
  270. a += [S.Zero]*(n - len(a))
  271. b += [S.Zero]*(n - len(b))
  272. c = [S.Zero]*n
  273. for mask in range(n):
  274. smask = mask
  275. while smask > 0:
  276. c[mask] += expand_mul(a[smask] * b[mask^smask])
  277. smask = (smask - 1)&mask
  278. c[mask] += expand_mul(a[smask] * b[mask^smask])
  279. return c
#----------------------------------------------------------------------------#
#                                                                            #
#                              Covering Product                              #
#                                                                            #
#----------------------------------------------------------------------------#
  285. def covering_product(a, b):
  286. """
  287. Returns the covering product of given sequences.
  288. The indices of each argument, considered as bit strings, correspond to
  289. subsets of a finite set.
  290. The covering product of given sequences is a sequence which contains
  291. the sum of products of the elements of the given sequences grouped by
  292. the *bitwise-OR* of the corresponding indices.
  293. The sequence is automatically padded to the right with zeros, as the
  294. definition of subset based on bitmasks (indices) requires the size of
  295. sequence to be a power of 2.
  296. Parameters
  297. ==========
  298. a, b : iterables
  299. The sequences for which covering product is to be obtained.
  300. Examples
  301. ========
  302. >>> from sympy import symbols, S, I, covering_product
  303. >>> u, v, x, y, z = symbols('u v x y z')
  304. >>> covering_product([u, v], [x, y])
  305. [u*x, u*y + v*x + v*y]
  306. >>> covering_product([u, v, x], [y, z])
  307. [u*y, u*z + v*y + v*z, x*y, x*z]
  308. >>> covering_product([1, S(2)/3], [3, 4 + 5*I])
  309. [3, 26/3 + 25*I/3]
  310. >>> covering_product([1, 3, S(5)/7], [7, 8])
  311. [7, 53, 5, 40/7]
  312. References
  313. ==========
  314. .. [1] https://people.csail.mit.edu/rrw/presentations/subset-conv.pdf
  315. """
  316. if not a or not b:
  317. return []
  318. a, b = a[:], b[:]
  319. n = max(len(a), len(b))
  320. if n&(n - 1): # not a power of 2
  321. n = 2**n.bit_length()
  322. # padding with zeros
  323. a += [S.Zero]*(n - len(a))
  324. b += [S.Zero]*(n - len(b))
  325. a, b = mobius_transform(a), mobius_transform(b)
  326. a = [expand_mul(x*y) for x, y in zip(a, b)]
  327. a = inverse_mobius_transform(a)
  328. return a
#----------------------------------------------------------------------------#
#                                                                            #
#                            Intersecting Product                            #
#                                                                            #
#----------------------------------------------------------------------------#
  334. def intersecting_product(a, b):
  335. """
  336. Returns the intersecting product of given sequences.
  337. The indices of each argument, considered as bit strings, correspond to
  338. subsets of a finite set.
  339. The intersecting product of given sequences is the sequence which
  340. contains the sum of products of the elements of the given sequences
  341. grouped by the *bitwise-AND* of the corresponding indices.
  342. The sequence is automatically padded to the right with zeros, as the
  343. definition of subset based on bitmasks (indices) requires the size of
  344. sequence to be a power of 2.
  345. Parameters
  346. ==========
  347. a, b : iterables
  348. The sequences for which intersecting product is to be obtained.
  349. Examples
  350. ========
  351. >>> from sympy import symbols, S, I, intersecting_product
  352. >>> u, v, x, y, z = symbols('u v x y z')
  353. >>> intersecting_product([u, v], [x, y])
  354. [u*x + u*y + v*x, v*y]
  355. >>> intersecting_product([u, v, x], [y, z])
  356. [u*y + u*z + v*y + x*y + x*z, v*z, 0, 0]
  357. >>> intersecting_product([1, S(2)/3], [3, 4 + 5*I])
  358. [9 + 5*I, 8/3 + 10*I/3]
  359. >>> intersecting_product([1, 3, S(5)/7], [7, 8])
  360. [327/7, 24, 0, 0]
  361. References
  362. ==========
  363. .. [1] https://people.csail.mit.edu/rrw/presentations/subset-conv.pdf
  364. """
  365. if not a or not b:
  366. return []
  367. a, b = a[:], b[:]
  368. n = max(len(a), len(b))
  369. if n&(n - 1): # not a power of 2
  370. n = 2**n.bit_length()
  371. # padding with zeros
  372. a += [S.Zero]*(n - len(a))
  373. b += [S.Zero]*(n - len(b))
  374. a, b = mobius_transform(a, subset=False), mobius_transform(b, subset=False)
  375. a = [expand_mul(x*y) for x, y in zip(a, b)]
  376. a = inverse_mobius_transform(a, subset=False)
  377. return a
#----------------------------------------------------------------------------#
#                                                                            #
#                            Integer Convolutions                            #
#                                                                            #
#----------------------------------------------------------------------------#
  383. def convolution_int(a, b):
  384. """Return the convolution of two sequences as a list.
  385. The iterables must consist solely of integers.
  386. Parameters
  387. ==========
  388. a, b : Sequence
  389. The sequences for which convolution is performed.
  390. Explanation
  391. ===========
  392. This function performs the convolution of ``a`` and ``b`` by packing
  393. each into a single integer, multiplying them together, and then
  394. unpacking the result from the product. The intuition behind this is
  395. that if we evaluate some polynomial [1]:
  396. .. math ::
  397. 1156x^6 + 3808x^5 + 8440x^4 + 14856x^3 + 16164x^2 + 14040x + 8100
  398. at say $x = 10^5$ we obtain $1156038080844014856161641404008100$.
  399. Note we can read of the coefficients for each term every five digits.
  400. If the $x$ we chose to evaluate at is large enough, the same will hold
  401. for the product.
  402. The idea now is since big integer multiplication in libraries such
  403. as GMP is highly optimised, this will be reasonably fast.
  404. Examples
  405. ========
  406. >>> from sympy.discrete.convolutions import convolution_int
  407. >>> convolution_int([2, 3], [4, 5])
  408. [8, 22, 15]
  409. >>> convolution_int([1, 1, -1], [1, 1])
  410. [1, 2, 0, -1]
  411. References
  412. ==========
  413. .. [1] Fateman, Richard J.
  414. Can you save time in multiplying polynomials by encoding them as integers?
  415. University of California, Berkeley, California (2004).
  416. https://people.eecs.berkeley.edu/~fateman/papers/polysbyGMP.pdf
  417. """
  418. # An upper bound on the largest coefficient in p(x)q(x) is given by (1 + min(dp, dq))N(p)N(q)
  419. # where dp = deg(p), dq = deg(q), N(f) denotes the coefficient of largest modulus in f [1]
  420. B = max(abs(c) for c in a)*max(abs(c) for c in b)*(1 + min(len(a) - 1, len(b) - 1))
  421. x, power = MPZ(1), 0
  422. while x <= (2*B): # multiply by two for negative coefficients, see [1]
  423. x <<= 1
  424. power += 1
  425. def to_integer(poly):
  426. n, mul = MPZ(0), 0
  427. for c in reversed(poly):
  428. if c and not mul: mul = -1 if c < 0 else 1
  429. n <<= power
  430. n += mul*int(c)
  431. return mul, n
  432. # Perform packing and multiplication
  433. (a_mul, a_packed), (b_mul, b_packed) = to_integer(a), to_integer(b)
  434. result = a_packed * b_packed
  435. # Perform unpacking
  436. mul = a_mul * b_mul
  437. mask, half, borrow, poly = x - 1, x >> 1, 0, []
  438. while result or borrow:
  439. coeff = (result & mask) + borrow
  440. result >>= power
  441. borrow = coeff >= half
  442. poly.append(mul * int(coeff if coeff < half else coeff - x))
  443. return poly or [0]