test_numpy_nodes.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. from itertools import product
  2. from sympy.core.singleton import S
  3. from sympy.core.symbol import symbols
  4. from sympy.functions.elementary.exponential import (exp, log)
  5. from sympy.functions.elementary.miscellaneous import Max, Min
  6. from sympy.printing.repr import srepr
  7. from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, minimum, maximum, amax, amin
  8. from sympy.testing.pytest import raises
# Symbols shared by every test in this module.
x, y, z = symbols('x y z')
  10. def test_logaddexp():
  11. lae_xy = logaddexp(x, y)
  12. ref_xy = log(exp(x) + exp(y))
  13. for wrt, deriv_order in product([x, y, z], range(3)):
  14. assert (
  15. lae_xy.diff(wrt, deriv_order) -
  16. ref_xy.diff(wrt, deriv_order)
  17. ).rewrite(log).simplify() == 0
  18. one_third_e = 1*exp(1)/3
  19. two_thirds_e = 2*exp(1)/3
  20. logThirdE = log(one_third_e)
  21. logTwoThirdsE = log(two_thirds_e)
  22. lae_sum_to_e = logaddexp(logThirdE, logTwoThirdsE)
  23. assert lae_sum_to_e.rewrite(log) == 1
  24. assert lae_sum_to_e.simplify() == 1
  25. was = logaddexp(2, 3)
  26. assert srepr(was) == srepr(was.simplify()) # cannot simplify with 2, 3
  27. def test_logaddexp2():
  28. lae2_xy = logaddexp2(x, y)
  29. ref2_xy = log(2**x + 2**y)/log(2)
  30. for wrt, deriv_order in product([x, y, z], range(3)):
  31. assert (
  32. lae2_xy.diff(wrt, deriv_order) -
  33. ref2_xy.diff(wrt, deriv_order)
  34. ).rewrite(log).cancel() == 0
  35. def lb(x):
  36. return log(x)/log(2)
  37. two_thirds = S.One*2/3
  38. four_thirds = 2*two_thirds
  39. lbTwoThirds = lb(two_thirds)
  40. lbFourThirds = lb(four_thirds)
  41. lae2_sum_to_2 = logaddexp2(lbTwoThirds, lbFourThirds)
  42. assert lae2_sum_to_2.rewrite(log) == 1
  43. assert lae2_sum_to_2.simplify() == 1
  44. was = logaddexp2(x, y)
  45. assert srepr(was) == srepr(was.simplify()) # cannot simplify with x, y
  46. def test_minimum_maximum():
  47. for MM, mm in zip([Min, Max], [minimum, maximum]):
  48. ref = MM(x, y, z)
  49. m = mm(x, y, z)
  50. assert m != ref
  51. assert m.rewrite(MM) == ref
  52. def test_amin_amax():
  53. for am in [amin, amax]:
  54. assert am(x).array == x
  55. assert am(x).axis == None
  56. assert am(x, axis=3).axis == 3
  57. with raises(ValueError):
  58. am(x, y, z)