test_cupy.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. from sympy.concrete.summations import Sum
  2. from sympy.functions.elementary.exponential import log
  3. from sympy.functions.elementary.miscellaneous import sqrt
  4. from sympy.utilities.lambdify import lambdify
  5. from sympy.abc import x, i, a, b
  6. from sympy.codegen.numpy_nodes import logaddexp
  7. from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions
  8. from sympy.testing.pytest import skip, raises
  9. from sympy.external import import_module
  10. cp = import_module('cupy')
  11. def test_cupy_print():
  12. prntr = CuPyPrinter()
  13. assert prntr.doprint(logaddexp(a, b)) == 'cupy.logaddexp(a, b)'
  14. assert prntr.doprint(sqrt(x)) == 'cupy.sqrt(x)'
  15. assert prntr.doprint(log(x)) == 'cupy.log(x)'
  16. assert prntr.doprint("acos(x)") == 'cupy.arccos(x)'
  17. assert prntr.doprint("exp(x)") == 'cupy.exp(x)'
  18. assert prntr.doprint("Abs(x)") == 'abs(x)'
  19. def test_not_cupy_print():
  20. prntr = CuPyPrinter()
  21. with raises(NotImplementedError):
  22. prntr.doprint("abcd(x)")
  23. def test_cupy_sum():
  24. if not cp:
  25. skip("CuPy not installed")
  26. s = Sum(x ** i, (i, a, b))
  27. f = lambdify((a, b, x), s, 'cupy')
  28. a_, b_ = 0, 10
  29. x_ = cp.linspace(-1, +1, 10)
  30. assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1)))
  31. s = Sum(i * x, (i, a, b))
  32. f = lambdify((a, b, x), s, 'numpy')
  33. a_, b_ = 0, 10
  34. x_ = cp.linspace(-1, +1, 10)
  35. assert cp.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1)))
  36. def test_cupy_known_funcs_consts():
  37. assert _cupy_known_constants['NaN'] == 'cupy.nan'
  38. assert _cupy_known_constants['EulerGamma'] == 'cupy.euler_gamma'
  39. assert _cupy_known_functions['acos'] == 'cupy.arccos'
  40. assert _cupy_known_functions['log'] == 'cupy.log'
  41. def test_cupy_print_methods():
  42. prntr = CuPyPrinter()
  43. assert hasattr(prntr, '_print_acos')
  44. assert hasattr(prntr, '_print_log')