test_radsimp.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. from sympy.core.add import Add
  2. from sympy.core.function import (Derivative, Function, diff)
  3. from sympy.core.mul import Mul
  4. from sympy.core.numbers import (I, Rational)
  5. from sympy.core.power import Pow
  6. from sympy.core.singleton import S
  7. from sympy.core.symbol import (Symbol, Wild, symbols)
  8. from sympy.functions.elementary.complexes import Abs
  9. from sympy.functions.elementary.exponential import (exp, log)
  10. from sympy.functions.elementary.miscellaneous import (root, sqrt)
  11. from sympy.functions.elementary.trigonometric import (cos, sin)
  12. from sympy.polys.polytools import factor
  13. from sympy.series.order import O
  14. from sympy.simplify.radsimp import (collect, collect_const, fraction, radsimp, rcollect)
  15. from sympy.core.expr import unchanged
  16. from sympy.core.mul import _unevaluated_Mul as umul
  17. from sympy.simplify.radsimp import (_unevaluated_Add,
  18. collect_sqrt, fraction_expand, collect_abs)
  19. from sympy.testing.pytest import raises
  20. from sympy.abc import x, y, z, a, b, c, d
  def test_radsimp():
      """Exercise radsimp on denominators containing square roots.

      Covers purely numeric radicals, radicals mixed with symbols, power
      rules, recursion into function arguments, the ``max_terms`` and
      ``symbolic`` options, and non-Expr containers (Integral, FiniteSet).
      """
      r2 = sqrt(2)
      r3 = sqrt(3)
      r5 = sqrt(5)
      r7 = sqrt(7)
      assert fraction(radsimp(1/r2)) == (sqrt(2), 2)
      assert radsimp(1/(1 + r2)) == \
          -1 + sqrt(2)
      assert radsimp(1/(r2 + r3)) == \
          -sqrt(2) + sqrt(3)
      assert fraction(radsimp(1/(1 + r2 + r3))) == \
          (-sqrt(6) + sqrt(2) + 2, 4)
      assert fraction(radsimp(1/(r2 + r3 + r5))) == \
          (-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12)
      assert fraction(radsimp(1/(1 + r2 + r3 + r5))) == (
          (-34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) + 14*sqrt(30) +
          93 + 46*sqrt(6) + 53*sqrt(5), 71))
      assert fraction(radsimp(1/(r2 + r3 + r5 + r7))) == (
          (-50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) + 22*sqrt(105)
          + 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7), 215))
      # Named ``q`` (not ``z``) so the imported sympy.abc symbol z is not
      # shadowed for the whole function body.
      q = radsimp(1/(1 + r2/3 + r3/5 + r5 + r7))
      assert len((3616791619821680643598*q).args) == 16
      assert radsimp(1/q) == 1/q
      assert radsimp(1/q, max_terms=20).expand() == 1 + r2/3 + r3/5 + r5 + r7
      assert radsimp(1/(r2*3)) == \
          sqrt(2)/6
      assert radsimp(1/(r2*a + r3 + r5 + r7)) == (
          (8*sqrt(2)*a**7 - 8*sqrt(7)*a**6 - 8*sqrt(5)*a**6 - 8*sqrt(3)*a**6 -
          180*sqrt(2)*a**5 + 8*sqrt(30)*a**5 + 8*sqrt(42)*a**5 + 8*sqrt(70)*a**5
          - 24*sqrt(105)*a**4 + 84*sqrt(3)*a**4 + 100*sqrt(5)*a**4 +
          116*sqrt(7)*a**4 - 72*sqrt(70)*a**3 - 40*sqrt(42)*a**3 -
          8*sqrt(30)*a**3 + 782*sqrt(2)*a**3 - 462*sqrt(3)*a**2 -
          302*sqrt(7)*a**2 - 254*sqrt(5)*a**2 + 120*sqrt(105)*a**2 -
          795*sqrt(2)*a - 62*sqrt(30)*a + 82*sqrt(42)*a + 98*sqrt(70)*a -
          118*sqrt(105) + 59*sqrt(7) + 295*sqrt(5) + 531*sqrt(3))/(16*a**8 -
          480*a**6 + 3128*a**4 - 6360*a**2 + 3481))
      assert radsimp(1/(r2*a + r2*b + r3 + r7)) == (
          (sqrt(2)*a*(a + b)**2 - 5*sqrt(2)*a + sqrt(42)*a + sqrt(2)*b*(a +
          b)**2 - 5*sqrt(2)*b + sqrt(42)*b - sqrt(7)*(a + b)**2 - sqrt(3)*(a +
          b)**2 - 2*sqrt(3) + 2*sqrt(7))/(2*a**4 + 8*a**3*b + 12*a**2*b**2 -
          20*a**2 + 8*a*b**3 - 40*a*b + 2*b**4 - 20*b**2 + 8))
      assert radsimp(1/(r2*a + r2*b + r2*c + r2*d)) == \
          sqrt(2)/(2*a + 2*b + 2*c + 2*d)
      assert radsimp(1/(1 + r2*a + r2*b + r2*c + r2*d)) == (
          (sqrt(2)*a + sqrt(2)*b + sqrt(2)*c + sqrt(2)*d - 1)/(2*a**2 + 4*a*b +
          4*a*c + 4*a*d + 2*b**2 + 4*b*c + 4*b*d + 2*c**2 + 4*c*d + 2*d**2 - 1))
      assert radsimp((y**2 - x)/(y - sqrt(x))) == \
          sqrt(x) + y
      assert radsimp(-(y**2 - x)/(y - sqrt(x))) == \
          -(sqrt(x) + y)
      assert radsimp(1/(1 - I + a*I)) == \
          (-I*a + 1 + I)/(a**2 - 2*a + 2)
      assert radsimp(1/((-x + y)*(x - sqrt(y)))) == \
          (-x - sqrt(y))/((x - y)*(x**2 - y))
      e = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y))
      assert radsimp(e) == x*(3 + 3*sqrt(2))*(3*x - 3*sqrt(y))
      assert radsimp(1/e) == (
          (-9*x + 9*sqrt(2)*x - 9*sqrt(y) + 9*sqrt(2)*sqrt(y))/(9*x*(9*x**2 -
          9*y)))
      assert radsimp(1 + 1/(1 + sqrt(3))) == \
          Mul(S.Half, -1 + sqrt(3), evaluate=False) + 1
      A = symbols("A", commutative=False)
      assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*A) == \
          x**2 + sqrt(2)*x**2 - sqrt(2)*x*A
      assert radsimp(1/sqrt(5 + 2 * sqrt(6))) == -sqrt(2) + sqrt(3)
      assert radsimp(1/sqrt(5 + 2 * sqrt(6))**3) == -(-sqrt(3) + sqrt(2))**3
      # issue 6532
      assert fraction(radsimp(1/sqrt(x))) == (sqrt(x), x)
      assert fraction(radsimp(1/sqrt(2*x + 3))) == (sqrt(2*x + 3), 2*x + 3)
      assert fraction(radsimp(1/sqrt(2*(x + 3)))) == (sqrt(2*x + 6), 2*x + 6)
      # issue 5994
      e = S('-(2 + 2*sqrt(2) + 4*2**(1/4))/'
          '(1 + 2**(3/4) + 3*2**(1/4) + 3*sqrt(2))')
      assert radsimp(e).expand() == -2*2**Rational(3, 4) - 2*2**Rational(1, 4) + 2 + 2*sqrt(2)
      # issue 5986 (modifications to radsimp didn't initially recognize this so
      # the test is included here)
      assert radsimp(1/(-sqrt(5)/2 - S.Half + (-sqrt(5)/2 - S.Half)**2)) == 1
      # from issue 5934
      eq = (
          (-240*sqrt(2)*sqrt(sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) -
          360*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) -
          120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) +
          120*sqrt(2)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
          120*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5) +
          120*sqrt(10)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) +
          120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5))/(-36000 -
          7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) +
          24*sqrt(10)*sqrt(-sqrt(5) + 5))**2))
      assert radsimp(eq) is S.NaN  # it's 0/0
      # work with normal form
      e = 1/sqrt(sqrt(7)/7 + 2*sqrt(2) + 3*sqrt(3) + 5*sqrt(5)) + 3
      assert radsimp(e) == (
          -sqrt(sqrt(7) + 14*sqrt(2) + 21*sqrt(3) +
          35*sqrt(5))*(-11654899*sqrt(35) - 1577436*sqrt(210) - 1278438*sqrt(15)
          - 1346996*sqrt(10) + 1635060*sqrt(6) + 5709765 + 7539830*sqrt(14) +
          8291415*sqrt(21))/1300423175 + 3)
      # obey power rules
      base = sqrt(3) - sqrt(2)
      assert radsimp(1/base**3) == (sqrt(3) + sqrt(2))**3
      assert radsimp(1/(-base)**3) == -(sqrt(2) + sqrt(3))**3
      assert radsimp(1/(-base)**x) == (-base)**(-x)
      assert radsimp(1/base**x) == (sqrt(2) + sqrt(3))**x
      assert radsimp(root(1/(-1 - sqrt(2)), -x)) == (-1)**(-1/x)*(1 + sqrt(2))**(1/x)
      # recurse into function arguments
      e = cos(1/(1 + sqrt(2)))
      assert radsimp(e) == cos(-sqrt(2) + 1)
      assert radsimp(e/2) == cos(-sqrt(2) + 1)/2
      assert radsimp(1/e) == 1/cos(-sqrt(2) + 1)
      assert radsimp(2/e) == 2/cos(-sqrt(2) + 1)
      assert fraction(radsimp(e/sqrt(x))) == (sqrt(x)*cos(-sqrt(2)+1), x)
      # test that symbolic denominators are not processed
      r = 1 + sqrt(2)
      assert radsimp(x/r, symbolic=False) == -x*(-sqrt(2) + 1)
      assert radsimp(x/(y + r), symbolic=False) == x/(y + 1 + sqrt(2))
      assert radsimp(x/(y + r)/r, symbolic=False) == \
          -x*(-sqrt(2) + 1)/(y + 1 + sqrt(2))
      # issue 7408
      eq = sqrt(x)/sqrt(y)
      assert radsimp(eq) == umul(sqrt(x), sqrt(y), 1/y)
      assert radsimp(eq, symbolic=False) == eq
      # issue 7498
      assert radsimp(sqrt(x)/sqrt(y)**3) == umul(sqrt(x), sqrt(y**3), 1/y**3)
      # for coverage
      eq = sqrt(x)/y**2
      assert radsimp(eq) == eq
      # handle non-Expr args
      from sympy.integrals.integrals import Integral
      eq = Integral(x/(sqrt(2) - 1), (x, 0, 1/(sqrt(2) + 1)))
      assert radsimp(eq) == Integral((sqrt(2) + 1)*x, (x, 0, sqrt(2) - 1))
      from sympy.sets import FiniteSet
      eq = FiniteSet(x/(sqrt(2) - 1))
      assert radsimp(eq) == FiniteSet((sqrt(2) + 1)*x)
  def test_radsimp_issue_3214():
      """Rationalize a ratio whose numerator and denominator differ only in
      the sign of a square-root term (issue 3214)."""
      cpos, ppos = symbols('c p', positive=True)
      root_term = sqrt(cpos**2 - ppos**2)
      shifted = cpos + I*ppos
      ratio = (shifted - root_term)/(shifted + root_term)
      expected = -I*(shifted - root_term)**2/(2*cpos*ppos)
      assert radsimp(ratio) == expected
  def test_collect_1():
      """Collect with respect to Symbol"""
      # Only x and y are needed here; the previously declared z and n were unused.
      x, y = symbols('x,y')
      assert collect(1, x) == 1
      assert collect(x + y*x, x) == x * (1 + y)
      assert collect(x + x**2, x) == x + x**2
      assert collect(x**2 + y*x**2, x) == (x**2)*(1 + y)
      assert collect(x**2 + y*x, x) == x*y + x**2
      assert collect(2*x**2 + y*x**2 + 3*x*y, [x]) == x**2*(2 + y) + 3*x*y
      assert collect(2*x**2 + y*x**2 + 3*x*y, [y]) == 2*x**2 + y*(x**2 + 3*x)
      assert collect(((1 + y + x)**4).expand(), x) == ((1 + y)**4).expand() + \
          x*(4*(1 + y)**3).expand() + x**2*(6*(1 + y)**2).expand() + \
          x**3*(4*(1 + y)).expand() + x**4
      # symbols can be given as any iterable
      expr = x + y
      assert collect(expr, expr.free_symbols) == expr
      assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None
          ) == x*exp(x) + 3*x + (y + 2)*sin(x)
      assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x + y*x +
          y*x*exp(x), x, exact=None
          ) == x*exp(x)*(y + 1) + (3 + y)*x + (y + 2)*sin(x)
  def test_collect_2():
      """Collect with respect to a sum"""
      a, b, x = symbols('a,b,x')
      trig_sum = sin(x) + cos(x)
      expr = a*trig_sum + b*trig_sum
      # the shared sum factor is pulled out, coefficients are added
      assert collect(expr, trig_sum) == (a + b)*trig_sum
  def test_collect_3():
      """Collect with respect to a product"""
      # c and n were declared but never used; only the needed symbols remain.
      a, b = symbols('a,b')
      f = Function('f')
      x, y, z = symbols('x,y,z')
      assert collect(-x/8 + x*y, -x) == x*(y - Rational(1, 8))
      assert collect(1 + x*(y**2), x*y) == 1 + x*(y**2)
      assert collect(x*y + a*x*y, x*y) == x*y*(1 + a)
      assert collect(1 + x*y + a*x*y, x*y) == 1 + x*y*(1 + a)
      assert collect(a*x*f(x) + b*(x*f(x)), x*f(x)) == x*(a + b)*f(x)
      assert collect(a*x*log(x) + b*(x*log(x)), x*log(x)) == x*(a + b)*log(x)
      assert collect(a*x**2*log(x)**2 + b*(x*log(x))**2, x*log(x)) == \
          x**2*log(x)**2*(a + b)
      # with respect to a product of three symbols
      assert collect(y*x*z + a*x*y*z, x*y*z) == (1 + a)*x*y*z
  def test_collect_4():
      """Collect with respect to a power"""
      a, b, c, x = symbols('a,b,c,x')
      assert collect(a*x**c + b*x**c, x**c) == x**c*(a + b)
      # issue 6096: the 2 stays with c (unless c is an integer or x is positive)
      assert collect(a*x**(2*c) + b*x**(2*c), x**c) == x**(2*c)*(a + b)
  def test_collect_5():
      """Collect with respect to a tuple"""
      # n was declared but never used.
      a, x, y, z = symbols('a,x,y,z')
      # either grouping of the x**2*y**4 term is acceptable
      assert collect(x**2*y**4 + z*(x*y**2)**2 + z + a*z, [x*y**2, z]) in [
          z*(1 + a + x**2*y**4) + x**2*y**4,
          z*(1 + a) + x**2*y**4*(1 + z)]
      assert collect((1 + (x + y) + (x + y)**2).expand(),
          [x, y]) == 1 + y + x*(1 + 2*y) + x**2 + y**2
  def test_collect_pr19431():
      """Unevaluated collect with respect to a product"""
      a = symbols('a')
      square = a**2
      terms = collect(square*(square + 1), square, evaluate=False)
      # the coefficient stored under a**2 is the remaining factor
      assert terms[square] == (square + 1)
  def test_collect_D():
      """Collect with respect to the first derivative of an undefined function,
      including terms in its higher derivatives and with ``exact=True``."""
      f = Function('f')
      x, a, b = symbols('x,a,b')
      fx = Derivative(f(x), x)
      dfx = Derivative(fx, x)
      fxx = Derivative(f(x), x, x)
      assert collect(a*fx + b*fx, fx) == (a + b)*fx
      assert collect(a*dfx + b*dfx, fx) == (a + b)*dfx
      assert collect(a*fxx + b*fxx, fx) == (a + b)*dfx
      # issue 4784
      assert collect(5*f(x) + 3*fx, fx) == 5*f(x) + 3*fx
      product_terms = f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x)
      # same answer with the default (exact=False) and with exact=True
      for exact in (False, True):
          assert collect(product_terms, f(x).diff(x), exact=exact) == \
              (x*f(x) + f(x))*Derivative(f(x), x) + f(x)
      quotient_terms = 1/f(x) + 1/f(x)*diff(f(x), x) + x*diff(f(x), x)/f(x)
      assert collect(quotient_terms, f(x).diff(x), exact=True) == \
          (1/f(x) + x/f(x))*Derivative(f(x), x) + 1/f(x)
      over_f = (1 + x*fx + fx)/f(x)
      assert collect(over_f.expand(), fx) == fx*(x/f(x) + 1/f(x)) + 1/f(x)
  def test_collect_func():
      """Collect powers of x, optionally post-processing each coefficient."""
      poly = ((x + a + 1)**3).expand()
      assert collect(poly, x) == a**3 + 3*a**2 + 3*a + x**3 + x**2*(3*a + 3) + \
          x*(3*a**2 + 6*a + 3) + 1
      assert collect(poly, x, factor) == x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + \
          (a + 1)**3
      # evaluate=False returns a mapping {power of x: coefficient}
      raw_coeffs = {
          S.One: a**3 + 3*a**2 + 3*a + 1,
          x: 3*a**2 + 6*a + 3,
          x**2: 3*a + 3,
          x**3: 1,
      }
      assert collect(poly, x, evaluate=False) == raw_coeffs
      factored_coeffs = {
          S.One: (a + 1)**3,
          x: 3*(a + 1)**2,
          x**2: umul(S(3), a + 1),
          x**3: 1,
      }
      assert collect(poly, x, factor, evaluate=False) == factored_coeffs
  def test_collect_order():
      """Collect expressions that carry an Order term, with and without
      ``distribute_order_term``."""
      a, b, x, t = symbols('a,b,x,t')
      big_o = O(x**3)
      assert collect(t + t*x + t*x**2 + big_o, t) == t*(1 + x + x**2 + big_o)
      assert collect(t + t*x + x**2 + big_o, t) == \
          t*(1 + x + big_o) + x**2 + big_o
      ungrouped = a*x + b*x + c*x**2 + d*x**2 + big_o
      grouped = x*(a + b) + x**2*(c + d) + big_o
      # True is the default value of distribute_order_term
      for distribute in (True, False):
          assert collect(ungrouped, x, distribute_order_term=distribute) == grouped
      expansion = sin(a + b).series(b, 0, 10)
      cos_series = cos(b).series(b, 0, 10)
      sin_series = sin(b).series(b, 0, 10)
      assert collect(expansion, [sin(a), cos(a)]) == \
          sin(a)*cos_series + cos(a)*sin_series
      assert collect(expansion, [sin(a), cos(a)], distribute_order_term=False) == \
          sin(a)*cos_series.removeO() + \
          cos(a)*sin_series.removeO() + O(b**10)
  def test_rcollect():
      """rcollect on a quotient and on a square root."""
      quotient = (x**2*y + x*y + x + y)/(x + y)
      assert rcollect(quotient, y) == (x + y*(1 + x + x**2))/(x + y)
      # z does not occur, so nothing changes
      radical = sqrt(-((x + 1)*(y + 1)))
      assert rcollect(radical, z) == radical
  def test_collect_D_0():
      """A second derivative can be collected with respect to itself."""
      f = Function('f')
      x, a, b = symbols('x,a,b')
      second = Derivative(f(x), x, x)
      assert collect(a*second + b*second, second) == (a + b)*second
  def test_collect_Wild():
      """Collect with respect to functions with Wild argument"""
      a, b, x, y = symbols('a b x y')
      f = Function('f')
      w1 = Wild('.1')
      w2 = Wild('.2')
      assert collect(f(x) + a*f(x), f(w1)) == (1 + a)*f(x)
      # for a two-argument application, arity and repeated wilds must match
      fxy = f(x, y)
      uncollected = fxy + a*fxy
      assert collect(uncollected, f(w1)) == uncollected
      assert collect(uncollected, f(w1, w2)) == (1 + a)*fxy
      assert collect(uncollected, f(w1, w1)) == uncollected
      fxx = f(x, x)
      assert collect(fxx + a*fxx, f(w1, w1)) == (1 + a)*fxx
      # powers with a wild base and/or a wild exponent
      power = (x + 1)**y
      summed = a*power + power
      assert collect(summed, w1**y) == (1 + a)*power
      assert collect(summed, w1**b) == summed
      for pattern in ((x + 1)**w2, w1**w2):
          assert collect(summed, pattern) == (1 + a)*power
  def test_collect_const():
      # coverage not provided by the tests above
      mixed = 2*sqrt(3) + 4*a*sqrt(5)
      # the numeric primitive is allowed to reabsorb into the sum
      assert collect_const(mixed) == 2*(2*sqrt(5)*a + sqrt(3))
      assert collect_const(mixed, sqrt(3)) == mixed
      with_radicals = sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)
      assert collect_const(with_radicals) == \
          sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3)
      # issue 5290
      linear = 2*x + 2*y + 1
      expected = Add(S.One, Mul(2, x + y, evaluate=False), evaluate=False)
      assert collect_const(linear, 2) == collect_const(linear) == expected
      assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False)
      signed = 2*x - 2*y - 2*z
      assert collect_const(signed, 2) == Mul(2, x - y - z, evaluate=False)
      assert collect_const(signed, -2) == \
          _unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False))
      # this is why the content_primitive is used
      eq = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2
      assert collect_sqrt(eq + 2) == \
          2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2
      # issue 16296
      halves = a + b + x/2 + y/2
      assert collect_const(halves) == a + b + Mul(S.Half, x + y, evaluate=False)
  def test_issue_13143():
      """Collecting a function together with its derivative (issue 13143)."""
      f = Function('f')
      g = f(x)
      dg = g.diff(x)
      simple_sum = g + dg + g*dg
      # collect function before derivative
      assert collect(simple_sum, Wild('w')) == g*(dg + 1) + dg
      with_x = g + g*dg + x*dg*g
      assert collect(with_x, dg) == (x*g + g)*dg + g
      assert collect(with_x, g) == (x*dg + dg + 1)*g
      # the order of the list decides which factor is pulled out
      assert collect(simple_sum, [g, dg]) == g*(1 + dg) + dg
      assert collect(simple_sum, [dg, g]) == dg*(1 + g) + g
  def test_issue_6097():
      """Float exponent 2.0*x collects onto base**x for symbolic and numeric bases."""
      for base in (y, 2):
          expr = a*base**(2.0*x) + b*base**(2.0*x)
          assert collect(expr, base**x) == (a + b)*(base**x)**2.0
  def test_fraction_expand():
      """expand(frac=True) and fraction_expand keep a single denominator."""
      ratio = (x + y)*y/x
      over_x = (x*y + y**2)/x
      assert ratio.expand(frac=True) == fraction_expand(ratio) == over_x
      # a plain expand splits the quotient into separate terms
      assert ratio.expand() == y + y**2/x
  def test_fraction():
      """Check ``fraction`` on symbolic quotients, noncommutative factors,
      exponentials with signed exponents and unevaluated products (with and
      without ``exact=True``)."""
      x, y, z = map(Symbol, 'xyz')
      A = Symbol('A', commutative=False)
      plain_cases = [
          (S.Half, (1, 2)),
          (x, (x, 1)),
          (1/x, (1, x)),
          (x/y, (x, y)),
          (x/2, (x, 2)),
          (x*y/z, (x*y, z)),
          (x/(y*z), (x, y*z)),
          (1/y**2, (1, y**2)),
          (x/y**2, (x, y**2)),
          ((x**2 + 1)/y, (x**2 + 1, y)),
          (x*(y + 1)/y**7, (x*(y + 1), y**7)),
          (x*A/y, (x*A, y)),
          (x*A**-1/y, (x*A**-1, y)),
      ]
      for expr, expected in plain_cases:
          assert fraction(expr) == expected
      assert fraction(exp(-x), exact=True) == (exp(-x), 1)
      assert fraction((1/(x + y))/2, exact=True) == \
          (1, Mul(2, (x + y), evaluate=False))
      neg = symbols('n', negative=True)
      assert fraction(exp(neg)) == (1, exp(-neg))
      assert fraction(exp(-neg)) == (exp(-neg), 1)
      pos = symbols('p', positive=True)
      assert fraction(exp(-pos)*log(pos), exact=True) == (exp(-pos)*log(pos), 1)
      halves = Mul(1, 1, S.Half, evaluate=False)
      assert fraction(halves) == (1, 2)
      assert fraction(halves, exact=True) == (Mul(1, 1, evaluate=False), 2)
      quarters = Mul(1, 1, S.Half, S.Half, Pow(1, -1, evaluate=False), evaluate=False)
      assert fraction(quarters) == (1, 4)
      assert fraction(quarters, exact=True) == \
          (Mul(1, 1, evaluate=False), Mul(2, 2, 1, evaluate=False))
  def test_issue_5615():
      """Collecting on aA**3/Re and a leaves this expanded sum unchanged."""
      alpha, reynolds, coef_a, coef_b, diam = symbols('aA Re a b D')
      expanded = ((diam**3*coef_a + coef_b*alpha**3)/reynolds).expand()
      assert collect(expanded, [alpha**3/reynolds, coef_a]) == expanded
  def test_issue_5933():
      """The centroid x-coordinate of a regular pentagon must not have a
      numerically vanishing denominator, before or after radsimp."""
      from sympy.geometry.polygon import Polygon, RegularPolygon
      from sympy.simplify.radsimp import denom
      pentagon = RegularPolygon((0, 0), 1, 5)
      cx = Polygon(*pentagon.vertices).centroid.x
      assert abs(denom(cx).n()) > 1e-12
      # also checked after radsimp, in case the earlier simplification did not handle it
      assert abs(denom(radsimp(cx))) > 1e-12
  def test_issue_14608():
      """collect in the presence of noncommutative symbols."""
      nc_a, nc_b = symbols('a b', commutative=False)
      x, y = symbols('x y')
      nc_terms = nc_a*nc_b + nc_b*nc_a
      # collecting on a noncommutative symbol raises AttributeError
      raises(AttributeError, lambda: collect(nc_terms, nc_a))
      commutative_part = x*y + y*(x + 1)
      assert collect(commutative_part, nc_a) == commutative_part
      assert collect(commutative_part + nc_terms, y) == y*(2*x + 1) + nc_terms
  def test_collect_abs():
      """collect_abs merges products of Abs factors into a single Abs."""
      abs_x, abs_y = abs(x), abs(y)
      total = abs_x + abs_y
      # sums are left alone
      assert collect_abs(total) == total
      # Mul by itself keeps the Abs factors separate ...
      assert unchanged(Mul, abs_x, abs_y)
      merged = Abs(x*y)
      assert isinstance(merged, Abs)
      # ... while collect_abs combines them, also inside function arguments
      assert collect_abs(abs_x*abs_y) == merged
      assert collect_abs(1 + exp(abs_x*abs_y)) == 1 + exp(merged)
      # See https://github.com/sympy/sympy/issues/12910
      pos = Symbol('p', positive=True)
      assert collect_abs(pos/abs(1 - pos)).is_commutative is True
  def test_issue_19149():
      """exp(3*x/4) is returned unchanged when collecting on exp(x)."""
      fractional = exp(3*x/4)
      assert collect(fractional, exp(x)) == fractional
  def test_issue_19719():
      """Unevaluated collect keys must be the requested patterns, not dummies."""
      a, b = symbols('a, b')
      expr = a**2 * (b + 1) + (7 + 1/b)/a
      patterns = (a**2, 1/a)
      by_pattern = collect(expr, patterns, evaluate=False)
      # Without xreplace the keys came back as e.g. _Dummy_20**(-2)
      assert by_pattern == {a**2: b + 1, 1/a: 7 + 1/b}
  def test_issue_21355():
      """radsimp leaves 1/(x + sqrt(x**2)) and 1/(x - sqrt(x**2)) unchanged."""
      for den in (x + sqrt(x**2), x - sqrt(x**2)):
          assert radsimp(1/den) == 1/den