# test_hyper.py

from sympy.core.containers import Tuple
from sympy.core.function import Derivative
from sympy.core.numbers import (I, Rational, oo, pi)
from sympy.core.singleton import S
from sympy.core.symbol import symbols
from sympy.functions.elementary.exponential import (exp, log)
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import cos
from sympy.functions.special.gamma_functions import gamma
from sympy.functions.special.hyper import (appellf1, hyper, meijerg)
from sympy.series.order import O
from sympy.abc import x, z, k
from sympy.series.limits import limit
from sympy.testing.pytest import raises, slow
from sympy.core.random import (
    random_complex_number as randcplx,
    verify_numerically as tn,
    test_derivative_numerically as td)


  19. def test_TupleParametersBase():
  20. # test that our implementation of the chain rule works
  21. p = hyper((), (), z**2)
  22. assert p.diff(z) == p*2*z
  23. def test_hyper():
  24. raises(TypeError, lambda: hyper(1, 2, z))
  25. assert hyper((2, 1), (1,), z) == hyper(Tuple(1, 2), Tuple(1), z)
  26. assert hyper((2, 1, 2), (1, 2, 1, 3), z) == hyper((2,), (1, 3), z)
  27. u = hyper((2, 1, 2), (1, 2, 1, 3), z, evaluate=False)
  28. assert u.ap == Tuple(1, 2, 2)
  29. assert u.bq == Tuple(1, 1, 2, 3)
  30. h = hyper((1, 2), (3, 4, 5), z)
  31. assert h.ap == Tuple(1, 2)
  32. assert h.bq == Tuple(3, 4, 5)
  33. assert h.argument == z
  34. assert h.is_commutative is True
  35. h = hyper((2, 1), (4, 3, 5), z)
  36. assert h.ap == Tuple(1, 2)
  37. assert h.bq == Tuple(3, 4, 5)
  38. assert h.argument == z
  39. assert h.is_commutative is True
  40. # just a few checks to make sure that all arguments go where they should
  41. assert tn(hyper(Tuple(), Tuple(), z), exp(z), z)
  42. assert tn(z*hyper((1, 1), Tuple(2), -z), log(1 + z), z)
  43. # differentiation
  44. h = hyper(
  45. (randcplx(), randcplx(), randcplx()), (randcplx(), randcplx()), z)
  46. assert td(h, z)
  47. a1, a2, b1, b2, b3 = symbols('a1:3, b1:4')
  48. assert hyper((a1, a2), (b1, b2, b3), z).diff(z) == \
  49. a1*a2/(b1*b2*b3) * hyper((a1 + 1, a2 + 1), (b1 + 1, b2 + 1, b3 + 1), z)
  50. # differentiation wrt parameters is not supported
  51. assert hyper([z], [], z).diff(z) == Derivative(hyper([z], [], z), z)
  52. # hyper is unbranched wrt parameters
  53. from sympy.functions.elementary.complexes import polar_lift
  54. assert hyper([polar_lift(z)], [polar_lift(k)], polar_lift(x)) == \
  55. hyper([z], [k], polar_lift(x))
  56. # hyper does not automatically evaluate anyway, but the test is to make
  57. # sure that the evaluate keyword is accepted
  58. assert hyper((1, 2), (1,), z, evaluate=False).func is hyper
  59. def test_expand_func():
  60. # evaluation at 1 of Gauss' hypergeometric function:
  61. from sympy.abc import a, b, c
  62. from sympy.core.function import expand_func
  63. a1, b1, c1 = randcplx(), randcplx(), randcplx() + 5
  64. assert expand_func(hyper([a, b], [c], 1)) == \
  65. gamma(c)*gamma(-a - b + c)/(gamma(-a + c)*gamma(-b + c))
  66. assert abs(expand_func(hyper([a1, b1], [c1], 1)).n()
  67. - hyper([a1, b1], [c1], 1).n()) < 1e-10
  68. # hyperexpand wrapper for hyper:
  69. assert expand_func(hyper([], [], z)) == exp(z)
  70. assert expand_func(hyper([1, 2, 3], [], z)) == hyper([1, 2, 3], [], z)
  71. assert expand_func(meijerg([[1, 1], []], [[1], [0]], z)) == log(z + 1)
  72. assert expand_func(meijerg([[1, 1], []], [[], []], z)) == \
  73. meijerg([[1, 1], []], [[], []], z)
  74. def replace_dummy(expr, sym):
  75. from sympy.core.symbol import Dummy
  76. dum = expr.atoms(Dummy)
  77. if not dum:
  78. return expr
  79. assert len(dum) == 1
  80. return expr.xreplace({dum.pop(): sym})
  81. def test_hyper_rewrite_sum():
  82. from sympy.concrete.summations import Sum
  83. from sympy.core.symbol import Dummy
  84. from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial)
  85. _k = Dummy("k")
  86. assert replace_dummy(hyper((1, 2), (1, 3), x).rewrite(Sum), _k) == \
  87. Sum(x**_k / factorial(_k) * RisingFactorial(2, _k) /
  88. RisingFactorial(3, _k), (_k, 0, oo))
  89. assert hyper((1, 2, 3), (-1, 3), z).rewrite(Sum) == \
  90. hyper((1, 2, 3), (-1, 3), z)
  91. def test_radius_of_convergence():
  92. assert hyper((1, 2), [3], z).radius_of_convergence == 1
  93. assert hyper((1, 2), [3, 4], z).radius_of_convergence is oo
  94. assert hyper((1, 2, 3), [4], z).radius_of_convergence == 0
  95. assert hyper((0, 1, 2), [4], z).radius_of_convergence is oo
  96. assert hyper((-1, 1, 2), [-4], z).radius_of_convergence == 0
  97. assert hyper((-1, -2, 2), [-1], z).radius_of_convergence is oo
  98. assert hyper((-1, 2), [-1, -2], z).radius_of_convergence == 0
  99. assert hyper([-1, 1, 3], [-2, 2], z).radius_of_convergence == 1
  100. assert hyper([-1, 1], [-2, 2], z).radius_of_convergence is oo
  101. assert hyper([-1, 1, 3], [-2], z).radius_of_convergence == 0
  102. assert hyper((-1, 2, 3, 4), [], z).radius_of_convergence is oo
  103. assert hyper([1, 1], [3], 1).convergence_statement == True
  104. assert hyper([1, 1], [2], 1).convergence_statement == False
  105. assert hyper([1, 1], [2], -1).convergence_statement == True
  106. assert hyper([1, 1], [1], -1).convergence_statement == False
  107. def test_meijer():
  108. raises(TypeError, lambda: meijerg(1, z))
  109. raises(TypeError, lambda: meijerg(((1,), (2,)), (3,), (4,), z))
  110. assert meijerg(((1, 2), (3,)), ((4,), (5,)), z) == \
  111. meijerg(Tuple(1, 2), Tuple(3), Tuple(4), Tuple(5), z)
  112. g = meijerg((1, 2), (3, 4, 5), (6, 7, 8, 9), (10, 11, 12, 13, 14), z)
  113. assert g.an == Tuple(1, 2)
  114. assert g.ap == Tuple(1, 2, 3, 4, 5)
  115. assert g.aother == Tuple(3, 4, 5)
  116. assert g.bm == Tuple(6, 7, 8, 9)
  117. assert g.bq == Tuple(6, 7, 8, 9, 10, 11, 12, 13, 14)
  118. assert g.bother == Tuple(10, 11, 12, 13, 14)
  119. assert g.argument == z
  120. assert g.nu == 75
  121. assert g.delta == -1
  122. assert g.is_commutative is True
  123. assert g.is_number is False
  124. #issue 13071
  125. assert meijerg([[],[]], [[S.Half],[0]], 1).is_number is True
  126. assert meijerg([1, 2], [3], [4], [5], z).delta == S.Half
  127. # just a few checks to make sure that all arguments go where they should
  128. assert tn(meijerg(Tuple(), Tuple(), Tuple(0), Tuple(), -z), exp(z), z)
  129. assert tn(sqrt(pi)*meijerg(Tuple(), Tuple(),
  130. Tuple(0), Tuple(S.Half), z**2/4), cos(z), z)
  131. assert tn(meijerg(Tuple(1, 1), Tuple(), Tuple(1), Tuple(0), z),
  132. log(1 + z), z)
  133. # test exceptions
  134. raises(ValueError, lambda: meijerg(((3, 1), (2,)), ((oo,), (2, 0)), x))
  135. raises(ValueError, lambda: meijerg(((3, 1), (2,)), ((1,), (2, 0)), x))
  136. # differentiation
  137. g = meijerg((randcplx(),), (randcplx() + 2*I,), Tuple(),
  138. (randcplx(), randcplx()), z)
  139. assert td(g, z)
  140. g = meijerg(Tuple(), (randcplx(),), Tuple(),
  141. (randcplx(), randcplx()), z)
  142. assert td(g, z)
  143. g = meijerg(Tuple(), Tuple(), Tuple(randcplx()),
  144. Tuple(randcplx(), randcplx()), z)
  145. assert td(g, z)
  146. a1, a2, b1, b2, c1, c2, d1, d2 = symbols('a1:3, b1:3, c1:3, d1:3')
  147. assert meijerg((a1, a2), (b1, b2), (c1, c2), (d1, d2), z).diff(z) == \
  148. (meijerg((a1 - 1, a2), (b1, b2), (c1, c2), (d1, d2), z)
  149. + (a1 - 1)*meijerg((a1, a2), (b1, b2), (c1, c2), (d1, d2), z))/z
  150. assert meijerg([z, z], [], [], [], z).diff(z) == \
  151. Derivative(meijerg([z, z], [], [], [], z), z)
  152. # meijerg is unbranched wrt parameters
  153. from sympy.functions.elementary.complexes import polar_lift as pl
  154. assert meijerg([pl(a1)], [pl(a2)], [pl(b1)], [pl(b2)], pl(z)) == \
  155. meijerg([a1], [a2], [b1], [b2], pl(z))
  156. # integrand
  157. from sympy.abc import a, b, c, d, s
  158. assert meijerg([a], [b], [c], [d], z).integrand(s) == \
  159. z**s*gamma(c - s)*gamma(-a + s + 1)/(gamma(b - s)*gamma(-d + s + 1))
  160. def test_meijerg_derivative():
  161. assert meijerg([], [1, 1], [0, 0, x], [], z).diff(x) == \
  162. log(z)*meijerg([], [1, 1], [0, 0, x], [], z) \
  163. + 2*meijerg([], [1, 1, 1], [0, 0, x, 0], [], z)
  164. y = randcplx()
  165. a = 5 # mpmath chokes with non-real numbers, and Mod1 with floats
  166. assert td(meijerg([x], [], [], [], y), x)
  167. assert td(meijerg([x**2], [], [], [], y), x)
  168. assert td(meijerg([], [x], [], [], y), x)
  169. assert td(meijerg([], [], [x], [], y), x)
  170. assert td(meijerg([], [], [], [x], y), x)
  171. assert td(meijerg([x], [a], [a + 1], [], y), x)
  172. assert td(meijerg([x], [a + 1], [a], [], y), x)
  173. assert td(meijerg([x, a], [], [], [a + 1], y), x)
  174. assert td(meijerg([x, a + 1], [], [], [a], y), x)
  175. b = Rational(3, 2)
  176. assert td(meijerg([a + 2], [b], [b - 3, x], [a], y), x)
  177. def test_meijerg_period():
  178. assert meijerg([], [1], [0], [], x).get_period() == 2*pi
  179. assert meijerg([1], [], [], [0], x).get_period() == 2*pi
  180. assert meijerg([], [], [0], [], x).get_period() == 2*pi # exp(x)
  181. assert meijerg(
  182. [], [], [0], [S.Half], x).get_period() == 2*pi # cos(sqrt(x))
  183. assert meijerg(
  184. [], [], [S.Half], [0], x).get_period() == 4*pi # sin(sqrt(x))
  185. assert meijerg([1, 1], [], [1], [0], x).get_period() is oo # log(1 + x)
  186. def test_hyper_unpolarify():
  187. from sympy.functions.elementary.exponential import exp_polar
  188. a = exp_polar(2*pi*I)*x
  189. b = x
  190. assert hyper([], [], a).argument == b
  191. assert hyper([0], [], a).argument == a
  192. assert hyper([0], [0], a).argument == b
  193. assert hyper([0, 1], [0], a).argument == a
  194. assert hyper([0, 1], [0], exp_polar(2*pi*I)).argument == 1
  195. @slow
  196. def test_hyperrep():
  197. from sympy.functions.special.hyper import (HyperRep, HyperRep_atanh,
  198. HyperRep_power1, HyperRep_power2, HyperRep_log1, HyperRep_asin1,
  199. HyperRep_asin2, HyperRep_sqrts1, HyperRep_sqrts2, HyperRep_log2,
  200. HyperRep_cosasin, HyperRep_sinasin)
  201. # First test the base class works.
  202. from sympy.functions.elementary.exponential import exp_polar
  203. from sympy.functions.elementary.piecewise import Piecewise
  204. a, b, c, d, z = symbols('a b c d z')
  205. class myrep(HyperRep):
  206. @classmethod
  207. def _expr_small(cls, x):
  208. return a
  209. @classmethod
  210. def _expr_small_minus(cls, x):
  211. return b
  212. @classmethod
  213. def _expr_big(cls, x, n):
  214. return c*n
  215. @classmethod
  216. def _expr_big_minus(cls, x, n):
  217. return d*n
  218. assert myrep(z).rewrite('nonrep') == Piecewise((0, abs(z) > 1), (a, True))
  219. assert myrep(exp_polar(I*pi)*z).rewrite('nonrep') == \
  220. Piecewise((0, abs(z) > 1), (b, True))
  221. assert myrep(exp_polar(2*I*pi)*z).rewrite('nonrep') == \
  222. Piecewise((c, abs(z) > 1), (a, True))
  223. assert myrep(exp_polar(3*I*pi)*z).rewrite('nonrep') == \
  224. Piecewise((d, abs(z) > 1), (b, True))
  225. assert myrep(exp_polar(4*I*pi)*z).rewrite('nonrep') == \
  226. Piecewise((2*c, abs(z) > 1), (a, True))
  227. assert myrep(exp_polar(5*I*pi)*z).rewrite('nonrep') == \
  228. Piecewise((2*d, abs(z) > 1), (b, True))
  229. assert myrep(z).rewrite('nonrepsmall') == a
  230. assert myrep(exp_polar(I*pi)*z).rewrite('nonrepsmall') == b
  231. def t(func, hyp, z):
  232. """ Test that func is a valid representation of hyp. """
  233. # First test that func agrees with hyp for small z
  234. if not tn(func.rewrite('nonrepsmall'), hyp, z,
  235. a=Rational(-1, 2), b=Rational(-1, 2), c=S.Half, d=S.Half):
  236. return False
  237. # Next check that the two small representations agree.
  238. if not tn(
  239. func.rewrite('nonrepsmall').subs(
  240. z, exp_polar(I*pi)*z).replace(exp_polar, exp),
  241. func.subs(z, exp_polar(I*pi)*z).rewrite('nonrepsmall'),
  242. z, a=Rational(-1, 2), b=Rational(-1, 2), c=S.Half, d=S.Half):
  243. return False
  244. # Next check continuity along exp_polar(I*pi)*t
  245. expr = func.subs(z, exp_polar(I*pi)*z).rewrite('nonrep')
  246. if abs(expr.subs(z, 1 + 1e-15).n() - expr.subs(z, 1 - 1e-15).n()) > 1e-10:
  247. return False
  248. # Finally check continuity of the big reps.
  249. def dosubs(func, a, b):
  250. rv = func.subs(z, exp_polar(a)*z).rewrite('nonrep')
  251. return rv.subs(z, exp_polar(b)*z).replace(exp_polar, exp)
  252. for n in [0, 1, 2, 3, 4, -1, -2, -3, -4]:
  253. expr1 = dosubs(func, 2*I*pi*n, I*pi/2)
  254. expr2 = dosubs(func, 2*I*pi*n + I*pi, -I*pi/2)
  255. if not tn(expr1, expr2, z):
  256. return False
  257. expr1 = dosubs(func, 2*I*pi*(n + 1), -I*pi/2)
  258. expr2 = dosubs(func, 2*I*pi*n + I*pi, I*pi/2)
  259. if not tn(expr1, expr2, z):
  260. return False
  261. return True
  262. # Now test the various representatives.
  263. a = Rational(1, 3)
  264. assert t(HyperRep_atanh(z), hyper([S.Half, 1], [Rational(3, 2)], z), z)
  265. assert t(HyperRep_power1(a, z), hyper([-a], [], z), z)
  266. assert t(HyperRep_power2(a, z), hyper([a, a - S.Half], [2*a], z), z)
  267. assert t(HyperRep_log1(z), -z*hyper([1, 1], [2], z), z)
  268. assert t(HyperRep_asin1(z), hyper([S.Half, S.Half], [Rational(3, 2)], z), z)
  269. assert t(HyperRep_asin2(z), hyper([1, 1], [Rational(3, 2)], z), z)
  270. assert t(HyperRep_sqrts1(a, z), hyper([-a, S.Half - a], [S.Half], z), z)
  271. assert t(HyperRep_sqrts2(a, z),
  272. -2*z/(2*a + 1)*hyper([-a - S.Half, -a], [S.Half], z).diff(z), z)
  273. assert t(HyperRep_log2(z), -z/4*hyper([Rational(3, 2), 1, 1], [2, 2], z), z)
  274. assert t(HyperRep_cosasin(a, z), hyper([-a, a], [S.Half], z), z)
  275. assert t(HyperRep_sinasin(a, z), 2*a*z*hyper([1 - a, 1 + a], [Rational(3, 2)], z), z)
  276. @slow
  277. def test_meijerg_eval():
  278. from sympy.functions.elementary.exponential import exp_polar
  279. from sympy.functions.special.bessel import besseli
  280. from sympy.abc import l
  281. a = randcplx()
  282. arg = x*exp_polar(k*pi*I)
  283. expr1 = pi*meijerg([[], [(a + 1)/2]], [[a/2], [-a/2, (a + 1)/2]], arg**2/4)
  284. expr2 = besseli(a, arg)
  285. # Test that the two expressions agree for all arguments.
  286. for x_ in [0.5, 1.5]:
  287. for k_ in [0.0, 0.1, 0.3, 0.5, 0.8, 1, 5.751, 15.3]:
  288. assert abs((expr1 - expr2).n(subs={x: x_, k: k_})) < 1e-10
  289. assert abs((expr1 - expr2).n(subs={x: x_, k: -k_})) < 1e-10
  290. # Test continuity independently
  291. eps = 1e-13
  292. expr2 = expr1.subs(k, l)
  293. for x_ in [0.5, 1.5]:
  294. for k_ in [0.5, Rational(1, 3), 0.25, 0.75, Rational(2, 3), 1.0, 1.5]:
  295. assert abs((expr1 - expr2).n(
  296. subs={x: x_, k: k_ + eps, l: k_ - eps})) < 1e-10
  297. assert abs((expr1 - expr2).n(
  298. subs={x: x_, k: -k_ + eps, l: -k_ - eps})) < 1e-10
  299. expr = (meijerg(((0.5,), ()), ((0.5, 0, 0.5), ()), exp_polar(-I*pi)/4)
  300. + meijerg(((0.5,), ()), ((0.5, 0, 0.5), ()), exp_polar(I*pi)/4)) \
  301. /(2*sqrt(pi))
  302. assert (expr - pi/exp(1)).n(chop=True) == 0
  303. def test_limits():
  304. k, x = symbols('k, x')
  305. assert hyper((1,), (Rational(4, 3), Rational(5, 3)), k**2).series(k) == \
  306. 1 + 9*k**2/20 + 81*k**4/1120 + O(k**6) # issue 6350
  307. # https://github.com/sympy/sympy/issues/11465
  308. assert limit(1/hyper((1, ), (1, ), x), x, 0) == 1
  309. def test_appellf1():
  310. a, b1, b2, c, x, y = symbols('a b1 b2 c x y')
  311. assert appellf1(a, b2, b1, c, y, x) == appellf1(a, b1, b2, c, x, y)
  312. assert appellf1(a, b1, b1, c, y, x) == appellf1(a, b1, b1, c, x, y)
  313. assert appellf1(a, b1, b2, c, S.Zero, S.Zero) is S.One
  314. f = appellf1(a, b1, b2, c, S.Zero, S.Zero, evaluate=False)
  315. assert f.func is appellf1
  316. assert f.doit() is S.One
  317. def test_derivative_appellf1():
  318. from sympy.core.function import diff
  319. a, b1, b2, c, x, y, z = symbols('a b1 b2 c x y z')
  320. assert diff(appellf1(a, b1, b2, c, x, y), x) == a*b1*appellf1(a + 1, b2, b1 + 1, c + 1, y, x)/c
  321. assert diff(appellf1(a, b1, b2, c, x, y), y) == a*b2*appellf1(a + 1, b1, b2 + 1, c + 1, x, y)/c
  322. assert diff(appellf1(a, b1, b2, c, x, y), z) == 0
  323. assert diff(appellf1(a, b1, b2, c, x, y), a) == Derivative(appellf1(a, b1, b2, c, x, y), a)
  324. def test_eval_nseries():
  325. a1, b1, a2, b2 = symbols('a1 b1 a2 b2')
  326. assert hyper((1,2), (1,2,3), x**2)._eval_nseries(x, 7, None) == \
  327. 1 + x**2/3 + x**4/24 + x**6/360 + O(x**7)
  328. assert exp(x)._eval_nseries(x,7,None) == \
  329. hyper((a1, b1), (a1, b1), x)._eval_nseries(x, 7, None)
  330. assert hyper((a1, a2), (b1, b2), x)._eval_nseries(z, 7, None) ==\
  331. hyper((a1, a2), (b1, b2), x) + O(z**7)
  332. assert hyper((-S(1)/2, S(1)/2), (1,), 4*x/(x + 1)).nseries(x) == \
  333. 1 - x + x**2/4 - 3*x**3/4 - 15*x**4/64 - 93*x**5/64 + O(x**6)
  334. assert (pi/2*hyper((-S(1)/2, S(1)/2), (1,), 4*x/(x + 1))).nseries(x) == \
  335. pi/2 - pi*x/2 + pi*x**2/8 - 3*pi*x**3/8 - 15*pi*x**4/128 - 93*pi*x**5/128 + O(x**6)