test_codeprinter.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from sympy.printing.codeprinter import CodePrinter, PrintMethodNotImplementedError
  2. from sympy.core import symbols
  3. from sympy.core.symbol import Dummy
  4. from sympy.testing.pytest import raises
  5. from sympy import cos
  6. from sympy.utilities.lambdify import lambdify
  7. from math import cos as math_cos
  8. from sympy.printing.lambdarepr import LambdaPrinter
  9. def setup_test_printer(**kwargs):
  10. p = CodePrinter(settings=kwargs)
  11. p._not_supported = set()
  12. p._number_symbols = set()
  13. return p
  14. def test_print_Dummy():
  15. d = Dummy('d')
  16. p = setup_test_printer()
  17. assert p._print_Dummy(d) == "d_%i" % d.dummy_index
  18. def test_print_Symbol():
  19. x, y = symbols('x, if')
  20. p = setup_test_printer()
  21. assert p._print(x) == 'x'
  22. assert p._print(y) == 'if'
  23. p.reserved_words.update(['if'])
  24. assert p._print(y) == 'if_'
  25. p = setup_test_printer(error_on_reserved=True)
  26. p.reserved_words.update(['if'])
  27. with raises(ValueError):
  28. p._print(y)
  29. p = setup_test_printer(reserved_word_suffix='_He_Man')
  30. p.reserved_words.update(['if'])
  31. assert p._print(y) == 'if_He_Man'
  32. def test_lambdify_LaTeX_symbols_issue_23374():
  33. # Create symbols with Latex style names
  34. x1, x2 = symbols("x_{1} x_2")
  35. # Lambdify the function
  36. f1 = lambdify([x1, x2], cos(x1 ** 2 + x2 ** 2))
  37. # Test that the function works correctly (numerically)
  38. assert f1(1, 2) == math_cos(1 ** 2 + 2 ** 2)
  39. # Explicitly generate a custom printer to verify the naming convention
  40. p = LambdaPrinter()
  41. expr_str = p.doprint(cos(x1 ** 2 + x2 ** 2))
  42. assert 'x_1' in expr_str
  43. assert 'x_2' in expr_str
  44. def test_issue_15791():
  45. class CrashingCodePrinter(CodePrinter):
  46. def emptyPrinter(self, obj):
  47. raise NotImplementedError
  48. from sympy.matrices import (
  49. MutableSparseMatrix,
  50. ImmutableSparseMatrix,
  51. )
  52. c = CrashingCodePrinter()
  53. # these should not silently succeed
  54. with raises(PrintMethodNotImplementedError):
  55. c.doprint(ImmutableSparseMatrix(2, 2, {}))
  56. with raises(PrintMethodNotImplementedError):
  57. c.doprint(MutableSparseMatrix(2, 2, {}))