gammasimp.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493
  1. from sympy.core import Function, S, Mul, Pow, Add
  2. from sympy.core.sorting import ordered, default_sort_key
  3. from sympy.core.function import expand_func
  4. from sympy.core.symbol import Dummy
  5. from sympy.functions import gamma, sqrt, sin
  6. from sympy.polys import factor, cancel
  7. from sympy.utilities.iterables import sift, uniq
  8. def gammasimp(expr):
  9. r"""
  10. Simplify expressions with gamma functions.
  11. Explanation
  12. ===========
  13. This function takes as input an expression containing gamma
  14. functions or functions that can be rewritten in terms of gamma
  15. functions and tries to minimize the number of those functions and
  16. reduce the size of their arguments.
  17. The algorithm works by rewriting all gamma functions as expressions
  18. involving rising factorials (Pochhammer symbols) and applies
  19. recurrence relations and other transformations applicable to rising
  20. factorials, to reduce their arguments, possibly letting the resulting
  21. rising factorial to cancel. Rising factorials with the second argument
  22. being an integer are expanded into polynomial forms and finally all
  23. other rising factorial are rewritten in terms of gamma functions.
  24. Then the following two steps are performed.
  25. 1. Reduce the number of gammas by applying the reflection theorem
  26. gamma(x)*gamma(1-x) == pi/sin(pi*x).
  27. 2. Reduce the number of gammas by applying the multiplication theorem
  28. gamma(x)*gamma(x+1/n)*...*gamma(x+(n-1)/n) == C*gamma(n*x).
  29. It then reduces the number of prefactors by absorbing them into gammas
  30. where possible and expands gammas with rational argument.
  31. All transformation rules can be found (or were derived from) here:
  32. .. [1] https://functions.wolfram.com/GammaBetaErf/Pochhammer/17/01/02/
  33. .. [2] https://functions.wolfram.com/GammaBetaErf/Pochhammer/27/01/0005/
  34. Examples
  35. ========
  36. >>> from sympy.simplify import gammasimp
  37. >>> from sympy import gamma, Symbol
  38. >>> from sympy.abc import x
  39. >>> n = Symbol('n', integer = True)
  40. >>> gammasimp(gamma(x)/gamma(x - 3))
  41. (x - 3)*(x - 2)*(x - 1)
  42. >>> gammasimp(gamma(n + 3))
  43. gamma(n + 3)
  44. """
  45. expr = expr.rewrite(gamma)
  46. # compute_ST will be looking for Functions and we don't want
  47. # it looking for non-gamma functions: issue 22606
  48. # so we mask free, non-gamma functions
  49. f = expr.atoms(Function)
  50. # take out gammas
  51. gammas = {i for i in f if isinstance(i, gamma)}
  52. if not gammas:
  53. return expr # avoid side effects like factoring
  54. f -= gammas
  55. # keep only those without bound symbols
  56. f = f & expr.as_dummy().atoms(Function)
  57. if f:
  58. dum, fun, simp = zip(*[
  59. (Dummy(), fi, fi.func(*[
  60. _gammasimp(a, as_comb=False) for a in fi.args]))
  61. for fi in ordered(f)])
  62. d = expr.xreplace(dict(zip(fun, dum)))
  63. return _gammasimp(d, as_comb=False).xreplace(dict(zip(dum, simp)))
  64. return _gammasimp(expr, as_comb=False)
def _gammasimp(expr, as_comb):
    """
    Helper function for gammasimp and combsimp.

    Explanation
    ===========

    Simplifies expressions written in terms of gamma function. If
    as_comb is True, it tries to preserve integer arguments. See
    docstring of gammasimp for more information. This was part of
    combsimp() in combsimp.py.

    Parameters
    ==========

    expr : Expr
        Expression whose gamma functions are to be simplified.
    as_comb : bool
        If True, remaining rising factorials are rebuilt as
        ``gamma(b + 1)`` and the reflection, duplication and
        multiplication-theorem passes are skipped. If False, rising
        factorials are rebuilt as ``gamma(a + b)/gamma(a)`` and all
        passes run.

    Returns
    =======

    Expr
        The simplified expression. The input is passed through
        ``factor`` before the gamma rules run and is re-factored
        afterwards if the rules changed it; finally every gamma with a
        Rational argument is expanded with ``expand_func``.
    """
    # Rewrite gamma(n) as the rising factorial rf(1, n - 1); _rf.eval
    # peels integer offsets off the argument into explicit polynomial
    # factors.
    expr = expr.replace(gamma,
        lambda n: _rf(1, (n - 1).expand()))

    # Turn any rising factorial that stayed unevaluated back into gammas.
    if as_comb:
        expr = expr.replace(_rf,
            lambda a, b: gamma(b + 1))
    else:
        expr = expr.replace(_rf,
            lambda a, b: gamma(a + b)/gamma(a))

    def rule_gamma(expr, level=0):
        """ Simplify products of gamma functions further.

        ``level`` selects the stage this call starts at:

        * 0 -- recurse into the arguments of ``expr`` first;
        * 1 -- split off non-commutative factors of a Mul;
        * 2 -- pre-simplify gamma ratios inside Add factors of the
          numerator/denominator;
        * 3 -- call ``rule_gamma(expr, 4)`` repeatedly until the result
          stops changing;
        * 4 -- collect gamma arguments and other factors, apply the
          reflection/duplication/multiplication theorems (unless
          ``as_comb``) and absorb factors into the gammas.
        """

        if expr.is_Atom:
            return expr

        def gamma_rat(x):
            # helper to simplify ratios of gammas
            # Each gamma(n) in x is round-tripped through _rf (whose eval
            # may split off integer shifts) and rebuilt as
            # gamma(a + b)/gamma(a); the rewrite is kept only if it reduces
            # the number of gamma occurrences.
            was = x.count(gamma)
            xx = x.replace(gamma, lambda n: _rf(1, (n - 1).expand()
                ).replace(_rf, lambda a, b: gamma(a + b)/gamma(a)))
            if xx.count(gamma) < was:
                x = xx
            return x

        def gamma_factor(x):
            # return True if there is a gamma factor in shallow args
            # (looks through Add/Mul args and through the base of a Pow
            # whose exponent is an integer or whose base is positive)
            if isinstance(x, gamma):
                return True
            if x.is_Add or x.is_Mul:
                return any(gamma_factor(xi) for xi in x.args)
            if x.is_Pow and (x.exp.is_integer or x.base.is_positive):
                return gamma_factor(x.base)
            return False

        # recursion step
        if level == 0:
            expr = expr.func(*[rule_gamma(x, level + 1) for x in expr.args])
            level += 1

        # Only products are processed by the remaining stages.
        if not expr.is_Mul:
            return expr

        # non-commutative step
        if level == 1:
            args, nc = expr.args_cnc()
            if not args:
                return expr
            if nc:
                # simplify the commutative part only and re-attach the
                # non-commutative factors unchanged
                return rule_gamma(Mul._from_args(args), level + 1)*Mul._from_args(nc)
            level += 1

        # pure gamma handling, not factor absorption
        if level == 2:
            # T: factors containing a gamma, F: everything else (kept aside)
            T, F = sift(expr.args, gamma_factor, binary=True)
            gamma_ind = Mul(*F)
            d = Mul(*T)

            nd, dd = d.as_numer_denom()
            # Pass 0 works on Add factors of the numerator (over dd); pass 1
            # swaps roles and works on Add factors of the denominator. The
            # swap at the end of pass 1 restores the original orientation.
            for ipass in range(2):
                args = list(ordered(Mul.make_args(nd)))
                for i, ni in enumerate(args):
                    if ni.is_Add:
                        # simplify each term over the current denominator,
                        # then recombine over a common denominator
                        ni, dd = Add(*[
                            rule_gamma(gamma_rat(a/dd), level + 1) for a in ni.args]
                            ).as_numer_denom()
                        args[i] = ni
                        if not dd.has(gamma):
                            break
                nd = Mul(*args)
                if ipass == 0 and not gamma_factor(nd):
                    break
                nd, dd = dd, nd  # now process in reversed order
            expr = gamma_ind*nd/dd
            # continue to the next stage only if gammas remain in a product
            if not (expr.is_Mul and (gamma_factor(dd) or gamma_factor(nd))):
                return expr
            level += 1

        # iteration until constant
        if level == 3:
            while True:
                was = expr
                expr = rule_gamma(expr, 4)
                if expr == was:
                    return expr

        # From here on (level 4): gather the arguments of gamma factors and
        # the remaining factors of the product, each split by whether they
        # sit in the numerator or the denominator.
        numer_gammas = []
        denom_gammas = []
        numer_others = []
        denom_others = []
        def explicate(p):
            # Classify a numerator/denominator piece p:
            #   S.One               -> (None, [])
            #   gamma(z)**k, k int  -> (True, [z]*k)   (gamma arguments)
            #   b**k, k int         -> (False, [b]*k)
            #   anything else       -> (False, [p])
            if p is S.One:
                return None, []
            b, e = p.as_base_exp()
            if e.is_Integer:
                if isinstance(b, gamma):
                    return True, [b.args[0]]*e
                else:
                    return False, [b]*e
            else:
                return False, [p]

        newargs = list(ordered(expr.args))
        while newargs:
            n, d = newargs.pop().as_numer_denom()
            isg, l = explicate(n)
            if isg:
                numer_gammas.extend(l)
            elif isg is False:
                numer_others.extend(l)
            isg, l = explicate(d)
            if isg:
                denom_gammas.extend(l)
            elif isg is False:
                denom_others.extend(l)

        # =========== level 2 work: pure gamma manipulation =========

        if not as_comb:
            # Try to reduce the number of gamma factors by applying the
            # reflection formula gamma(x)*gamma(1-x) = pi/sin(pi*x)
            # For a pair with g1 + g2 - 1 == n (an integer), g2 = 1 - g1 + n;
            # the n-shift is compensated by n explicit linear factors in the
            # numerator (n > 0) or -n factors in the denominator (n < 0).
            for gammas, numer, denom in [(
                numer_gammas, numer_others, denom_others),
                    (denom_gammas, denom_others, numer_others)]:
                new = []
                while gammas:
                    g1 = gammas.pop()
                    # integer arguments are left alone (sin(pi*g1) would
                    # vanish there)
                    if g1.is_integer:
                        new.append(g1)
                        continue
                    for i, g2 in enumerate(gammas):
                        n = g1 + g2 - 1
                        if not n.is_Integer:
                            continue
                        numer.append(S.Pi)
                        denom.append(sin(S.Pi*g1))
                        gammas.pop(i)
                        if n > 0:
                            numer.extend(1 - g1 + k for k in range(n))
                        elif n < 0:
                            denom.extend(-g1 - k for k in range(-n))
                        break
                    else:
                        new.append(g1)
                # /!\ updating IN PLACE
                gammas[:] = new

            # Try to reduce the number of gammas by using the duplication
            # theorem to cancel an upper and lower: gamma(2*s)/gamma(s) =
            # 2**(2*s + 1)/(4*sqrt(pi))*gamma(s + 1/2). Although this could
            # be done with higher argument ratios like gamma(3*x)/gamma(x),
            # this would not reduce the number of gammas as in this case.
            # Pairs are matched when x - 2*y is an integer n; the integer
            # offset is compensated by explicit linear factors.
            for ng, dg, no, do in [(numer_gammas, denom_gammas, numer_others,
                                    denom_others),
                                   (denom_gammas, numer_gammas, denom_others,
                                    numer_others)]:

                while True:
                    # find the first (x, y) with x in ng, y in dg and
                    # x - 2*y an integer; stop when none is left
                    for x in ng:
                        for y in dg:
                            n = x - 2*y
                            if n.is_Integer:
                                break
                        else:
                            continue
                        break
                    else:
                        break
                    ng.remove(x)
                    dg.remove(y)
                    if n > 0:
                        no.extend(2*y + k for k in range(n))
                    elif n < 0:
                        do.extend(2*y - 1 - k for k in range(-n))
                    ng.append(y + S.Half)
                    no.append(2**(2*y - 1))
                    do.append(sqrt(S.Pi))

            # Try to reduce the number of gamma factors by applying the
            # multiplication theorem (used when n gammas with args differing
            # by 1/n mod 1 are encountered).
            #
            # run of 2 with args differing by 1/2
            #
            # >>> gammasimp(gamma(x)*gamma(x+S.Half))
            # 2*sqrt(2)*2**(-2*x - 1/2)*sqrt(pi)*gamma(2*x)
            #
            # run of 3 args differing by 1/3 (mod 1)
            #
            # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(2)/3))
            # 6*3**(-3*x - 1/2)*pi*gamma(3*x)
            # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(5)/3))
            # 2*3**(-3*x - 1/2)*pi*(3*x + 2)*gamma(3*x)
            #
            def _run(coeffs):
                # find runs in coeffs such that the difference in terms (mod 1)
                # of t1, t2, ..., tn is 1/n
                # Returns (n, first_coeff, other_coeffs) for the first run
                # found and removes those coefficients from `coeffs` in
                # place; returns None when no run exists.
                # NOTE(review): `.p`/`.q` are Rational attributes; this
                # assumes the coefficients from as_coeff_Add are Rational.
                # A Float constant term in a gamma argument may not satisfy
                # that -- confirm whether it is reachable.
                u = list(uniq(coeffs))
                for i in range(len(u)):
                    dj = ([((u[j] - u[i]) % 1, j) for j in range(i + 1, len(u))])
                    for one, j in dj:
                        if one.p == 1 and one.q != 1:
                            n = one.q
                            got = [i]
                            get = list(range(1, n))
                            for d, j in dj:
                                m = n*d
                                if m.is_Integer and m in get:
                                    get.remove(m)
                                    got.append(j)
                                    if not get:
                                        break
                            else:
                                continue
                            for i, j in enumerate(got):
                                c = u[j]
                                coeffs.remove(c)
                                got[i] = c
                            return one.q, got[0], got[1:]

            def _mult_thm(gammas, numer, denom):
                # pull off and analyze the leading coefficient from each gamma arg
                # looking for runs in those Rationals
                # Applies prod_{k=0}^{n-1} gamma(z + k/n)
                #     == (2*pi)**((n - 1)/2) * n**(1/2 - n*z) * gamma(n*z).
                # `gammas` is rewritten IN PLACE and the constant factor is
                # appended to `numer`; `denom` is accepted but not used.

                # expr -> coeff + resid -> rats[resid] = coeff
                rats = {}
                for g in gammas:
                    c, resid = g.as_coeff_Add()
                    rats.setdefault(resid, []).append(c)

                # look for runs in Rationals for each resid
                keys = sorted(rats, key=default_sort_key)
                for resid in keys:
                    coeffs = sorted(rats[resid])
                    new = []
                    while True:
                        run = _run(coeffs)
                        if run is None:
                            break

                        # process the sequence that was found:
                        # 1) convert all the gamma functions to have the right
                        #    argument (could be off by an integer)
                        # 2) append the factors corresponding to the theorem
                        # 3) append the new gamma function

                        n, ui, other = run

                        # (1)
                        # gamma(resid + u) ==
                        #     gamma(resid + u - m)*prod(resid + u - 1 - k),
                        # with m = int(u - ui) and k = 0..m-1
                        for u in other:
                            con = resid + u - 1
                            for k in range(int(u - ui)):
                                numer.append(con - k)

                        con = n*(resid + ui)  # for (2) and (3)

                        # (2)
                        numer.append((2*S.Pi)**(S(n - 1)/2)*
                                     n**(S.Half - con))
                        # (3)
                        new.append(con)

                    # restore resid to coeffs
                    rats[resid] = [resid + c for c in coeffs] + new

                # rebuild the gamma arguments
                g = []
                for resid in keys:
                    g += rats[resid]
                # /!\ updating IN PLACE
                gammas[:] = g

            for l, numer, denom in [(numer_gammas, numer_others, denom_others),
                                    (denom_gammas, denom_others, numer_others)]:
                _mult_thm(l, numer, denom)

        # =========== level >= 2 work: factor absorption =========

        # NOTE(review): this point is only reached with level == 4 (stages
        # 0-3 above either return or increment), so the test below always
        # passes.
        if level >= 2:
            # Try to absorb factors into the gammas: x*gamma(x) -> gamma(x + 1)
            # and gamma(x)/(x - 1) -> gamma(x - 1)
            # This code (in particular repeated calls to find_fuzzy) can be very
            # slow.
            def find_fuzzy(l, x):
                # Return the first y in l such that x/y cancels to an
                # expression free of symbols (while x or y has symbols),
                # pre-filtered by the cheap invariants in `inv`; else None.
                if not l:
                    return
                S1, T1 = compute_ST(x)
                for y in l:
                    S2, T2 = inv[y]
                    if T1 != T2 or (not S1.intersection(S2) and
                                    (S1 != set() or S2 != set())):
                        continue
                    # XXX we want some simplification (e.g. cancel or
                    # simplify) but no matter what it's slow.
                    a = len(cancel(x/y).free_symbols)
                    b = len(x.free_symbols)
                    c = len(y.free_symbols)
                    # TODO is there a better heuristic?
                    if a == 0 and (b > 0 or c > 0):
                        return y

            # We thus try to avoid expensive calls by building the following
            # "invariants": For every factor or gamma function argument
            #   - the set of free symbols S
            #   - the set of functional components T
            # We will only try to absorb if T1==T2 and (S1 intersect S2 != emptyset
            # or S1 == S2 == emptyset)
            inv = {}

            def compute_ST(expr):
                # (free symbols, Function atoms plus exponents of Pow atoms),
                # served from the `inv` cache when already computed
                if expr in inv:
                    return inv[expr]
                return (expr.free_symbols, expr.atoms(Function).union(
                        {e.exp for e in expr.atoms(Pow)}))

            def update_ST(expr):
                inv[expr] = compute_ST(expr)
            for expr in numer_gammas + denom_gammas + numer_others + denom_others:
                update_ST(expr)

            for gammas, numer, denom in [(
                numer_gammas, numer_others, denom_others),
                    (denom_gammas, denom_others, numer_others)]:
                new = []
                while gammas:
                    g = gammas.pop()
                    cont = True
                    while cont:
                        cont = False
                        # y ~ g on the same side: y*gamma(g) -> gamma(g + 1),
                        # keeping the constant ratio y/g as a leftover factor
                        y = find_fuzzy(numer, g)
                        if y is not None:
                            numer.remove(y)
                            if y != g:
                                numer.append(y/g)
                                update_ST(y/g)
                            g += 1
                            cont = True
                        # y ~ g - 1 on the opposite side:
                        # gamma(g)/y -> gamma(g - 1), leftover (g - 1)/y
                        y = find_fuzzy(denom, g - 1)
                        if y is not None:
                            denom.remove(y)
                            if y != g - 1:
                                numer.append((g - 1)/y)
                                update_ST((g - 1)/y)
                            g -= 1
                            cont = True
                    new.append(g)
                # /!\ updating IN PLACE
                gammas[:] = new

        # =========== rebuild expr ==================================

        return Mul(*[gamma(g) for g in numer_gammas]) \
            / Mul(*[gamma(g) for g in denom_gammas]) \
            * Mul(*numer_others) / Mul(*denom_others)

    was = factor(expr)
    # (for some reason we cannot use Basic.replace in this case)
    expr = rule_gamma(was)
    if expr != was:
        expr = factor(expr)

    # expand gammas whose argument is a Rational number
    expr = expr.replace(gamma,
        lambda n: expand_func(gamma(n)) if n.is_Rational else gamma(n))

    return expr
  400. class _rf(Function):
  401. @classmethod
  402. def eval(cls, a, b):
  403. if b.is_Integer:
  404. if not b:
  405. return S.One
  406. n = int(b)
  407. if n > 0:
  408. return Mul(*[a + i for i in range(n)])
  409. elif n < 0:
  410. return 1/Mul(*[a - i for i in range(1, -n + 1)])
  411. else:
  412. if b.is_Add:
  413. c, _b = b.as_coeff_Add()
  414. if c.is_Integer:
  415. if c > 0:
  416. return _rf(a, _b)*_rf(a + _b, c)
  417. elif c < 0:
  418. return _rf(a, _b)/_rf(a + _b + c, -c)
  419. if a.is_Add:
  420. c, _a = a.as_coeff_Add()
  421. if c.is_Integer:
  422. if c > 0:
  423. return _rf(_a, b)*_rf(_a + b, c)/_rf(_a, c)
  424. elif c < 0:
  425. return _rf(_a, b)*_rf(_a + c, -c)/_rf(_a + b + c, -c)