python.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import keyword as kw
  2. import sympy
  3. from .repr import ReprPrinter
  4. from .str import StrPrinter
  5. # A list of classes that should be printed using StrPrinter
  6. STRPRINT = ("Add", "Infinity", "Integer", "Mul", "NegativeInfinity", "Pow")
  7. class PythonPrinter(ReprPrinter, StrPrinter):
  8. """A printer which converts an expression into its Python interpretation."""
  9. def __init__(self, settings=None):
  10. super().__init__(settings)
  11. self.symbols = []
  12. self.functions = []
  13. # Create print methods for classes that should use StrPrinter instead
  14. # of ReprPrinter.
  15. for name in STRPRINT:
  16. f_name = "_print_%s" % name
  17. f = getattr(StrPrinter, f_name)
  18. setattr(PythonPrinter, f_name, f)
  19. def _print_Function(self, expr):
  20. func = expr.func.__name__
  21. if not hasattr(sympy, func) and func not in self.functions:
  22. self.functions.append(func)
  23. return StrPrinter._print_Function(self, expr)
  24. # procedure (!) for defining symbols which have be defined in print_python()
  25. def _print_Symbol(self, expr):
  26. symbol = self._str(expr)
  27. if symbol not in self.symbols:
  28. self.symbols.append(symbol)
  29. return StrPrinter._print_Symbol(self, expr)
  30. def _print_module(self, expr):
  31. raise ValueError('Modules in the expression are unacceptable')
  32. def python(expr, **settings):
  33. """Return Python interpretation of passed expression
  34. (can be passed to the exec() function without any modifications)"""
  35. printer = PythonPrinter(settings)
  36. exprp = printer.doprint(expr)
  37. result = ''
  38. # Returning found symbols and functions
  39. renamings = {}
  40. for symbolname in printer.symbols:
  41. # Remove curly braces from subscripted variables
  42. if '{' in symbolname:
  43. newsymbolname = symbolname.replace('{', '').replace('}', '')
  44. renamings[sympy.Symbol(symbolname)] = newsymbolname
  45. else:
  46. newsymbolname = symbolname
  47. # Escape symbol names that are reserved Python keywords
  48. if kw.iskeyword(newsymbolname):
  49. while True:
  50. newsymbolname += "_"
  51. if (newsymbolname not in printer.symbols and
  52. newsymbolname not in printer.functions):
  53. renamings[sympy.Symbol(
  54. symbolname)] = sympy.Symbol(newsymbolname)
  55. break
  56. result += newsymbolname + ' = Symbol(\'' + symbolname + '\')\n'
  57. for functionname in printer.functions:
  58. newfunctionname = functionname
  59. # Escape function names that are reserved Python keywords
  60. if kw.iskeyword(newfunctionname):
  61. while True:
  62. newfunctionname += "_"
  63. if (newfunctionname not in printer.symbols and
  64. newfunctionname not in printer.functions):
  65. renamings[sympy.Function(
  66. functionname)] = sympy.Function(newfunctionname)
  67. break
  68. result += newfunctionname + ' = Function(\'' + functionname + '\')\n'
  69. if renamings:
  70. exprp = expr.subs(renamings)
  71. result += 'e = ' + printer._str(exprp)
  72. return result
  73. def print_python(expr, **settings):
  74. """Print output of python() function"""
  75. print(python(expr, **settings))