cxx.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. """
  2. C++ code printer
  3. """
  4. from itertools import chain
  5. from sympy.codegen.ast import Type, none
  6. from .codeprinter import requires
  7. from .c import C89CodePrinter, C99CodePrinter
  8. # These are defined in the other file so we can avoid importing sympy.codegen
  9. # from the top-level 'import sympy'. Export them here as well.
  10. from sympy.printing.codeprinter import cxxcode # noqa:F401
  11. # from https://en.cppreference.com/w/cpp/keyword
  12. reserved = {
  13. 'C++98': [
  14. 'and', 'and_eq', 'asm', 'auto', 'bitand', 'bitor', 'bool', 'break',
  15. 'case', 'catch,', 'char', 'class', 'compl', 'const', 'const_cast',
  16. 'continue', 'default', 'delete', 'do', 'double', 'dynamic_cast',
  17. 'else', 'enum', 'explicit', 'export', 'extern', 'false', 'float',
  18. 'for', 'friend', 'goto', 'if', 'inline', 'int', 'long', 'mutable',
  19. 'namespace', 'new', 'not', 'not_eq', 'operator', 'or', 'or_eq',
  20. 'private', 'protected', 'public', 'register', 'reinterpret_cast',
  21. 'return', 'short', 'signed', 'sizeof', 'static', 'static_cast',
  22. 'struct', 'switch', 'template', 'this', 'throw', 'true', 'try',
  23. 'typedef', 'typeid', 'typename', 'union', 'unsigned', 'using',
  24. 'virtual', 'void', 'volatile', 'wchar_t', 'while', 'xor', 'xor_eq'
  25. ]
  26. }
  27. reserved['C++11'] = reserved['C++98'][:] + [
  28. 'alignas', 'alignof', 'char16_t', 'char32_t', 'constexpr', 'decltype',
  29. 'noexcept', 'nullptr', 'static_assert', 'thread_local'
  30. ]
  31. reserved['C++17'] = reserved['C++11'][:]
  32. reserved['C++17'].remove('register')
  33. # TM TS: atomic_cancel, atomic_commit, atomic_noexcept, synchronized
  34. # concepts TS: concept, requires
  35. # module TS: import, module
  36. _math_functions = {
  37. 'C++98': {
  38. 'Mod': 'fmod',
  39. 'ceiling': 'ceil',
  40. },
  41. 'C++11': {
  42. 'gamma': 'tgamma',
  43. },
  44. 'C++17': {
  45. 'beta': 'beta',
  46. 'Ei': 'expint',
  47. 'zeta': 'riemann_zeta',
  48. }
  49. }
  50. # from https://en.cppreference.com/w/cpp/header/cmath
  51. for k in ('Abs', 'exp', 'log', 'log10', 'sqrt', 'sin', 'cos', 'tan', # 'Pow'
  52. 'asin', 'acos', 'atan', 'atan2', 'sinh', 'cosh', 'tanh', 'floor'):
  53. _math_functions['C++98'][k] = k.lower()
  54. for k in ('asinh', 'acosh', 'atanh', 'erf', 'erfc'):
  55. _math_functions['C++11'][k] = k.lower()
  56. def _attach_print_method(cls, sympy_name, func_name):
  57. meth_name = '_print_%s' % sympy_name
  58. if hasattr(cls, meth_name):
  59. raise ValueError("Edit method (or subclass) instead of overwriting.")
  60. def _print_method(self, expr):
  61. return '{}{}({})'.format(self._ns, func_name, ', '.join(map(self._print, expr.args)))
  62. _print_method.__doc__ = "Prints code for %s" % k
  63. setattr(cls, meth_name, _print_method)
  64. def _attach_print_methods(cls, cont):
  65. for sympy_name, cxx_name in cont[cls.standard].items():
  66. _attach_print_method(cls, sympy_name, cxx_name)
  67. class _CXXCodePrinterBase:
  68. printmethod = "_cxxcode"
  69. language = 'C++'
  70. _ns = 'std::' # namespace
  71. def __init__(self, settings=None):
  72. super().__init__(settings or {})
  73. @requires(headers={'algorithm'})
  74. def _print_Max(self, expr):
  75. from sympy.functions.elementary.miscellaneous import Max
  76. if len(expr.args) == 1:
  77. return self._print(expr.args[0])
  78. return "%smax(%s, %s)" % (self._ns, self._print(expr.args[0]),
  79. self._print(Max(*expr.args[1:])))
  80. @requires(headers={'algorithm'})
  81. def _print_Min(self, expr):
  82. from sympy.functions.elementary.miscellaneous import Min
  83. if len(expr.args) == 1:
  84. return self._print(expr.args[0])
  85. return "%smin(%s, %s)" % (self._ns, self._print(expr.args[0]),
  86. self._print(Min(*expr.args[1:])))
  87. def _print_using(self, expr):
  88. if expr.alias == none:
  89. return 'using %s' % expr.type
  90. else:
  91. raise ValueError("C++98 does not support type aliases")
  92. def _print_Raise(self, rs):
  93. arg, = rs.args
  94. return 'throw %s' % self._print(arg)
  95. @requires(headers={'stdexcept'})
  96. def _print_RuntimeError_(self, re):
  97. message, = re.args
  98. return "%sruntime_error(%s)" % (self._ns, self._print(message))
  99. class CXX98CodePrinter(_CXXCodePrinterBase, C89CodePrinter):
  100. standard = 'C++98'
  101. reserved_words = set(reserved['C++98'])
  102. # _attach_print_methods(CXX98CodePrinter, _math_functions)
  103. class CXX11CodePrinter(_CXXCodePrinterBase, C99CodePrinter):
  104. standard = 'C++11'
  105. reserved_words = set(reserved['C++11'])
  106. type_mappings = dict(chain(
  107. CXX98CodePrinter.type_mappings.items(),
  108. {
  109. Type('int8'): ('int8_t', {'cstdint'}),
  110. Type('int16'): ('int16_t', {'cstdint'}),
  111. Type('int32'): ('int32_t', {'cstdint'}),
  112. Type('int64'): ('int64_t', {'cstdint'}),
  113. Type('uint8'): ('uint8_t', {'cstdint'}),
  114. Type('uint16'): ('uint16_t', {'cstdint'}),
  115. Type('uint32'): ('uint32_t', {'cstdint'}),
  116. Type('uint64'): ('uint64_t', {'cstdint'}),
  117. Type('complex64'): ('std::complex<float>', {'complex'}),
  118. Type('complex128'): ('std::complex<double>', {'complex'}),
  119. Type('bool'): ('bool', None),
  120. }.items()
  121. ))
  122. def _print_using(self, expr):
  123. if expr.alias == none:
  124. return super()._print_using(expr)
  125. else:
  126. return 'using %(alias)s = %(type)s' % expr.kwargs(apply=self._print)
  127. # _attach_print_methods(CXX11CodePrinter, _math_functions)
  128. class CXX17CodePrinter(_CXXCodePrinterBase, C99CodePrinter):
  129. standard = 'C++17'
  130. reserved_words = set(reserved['C++17'])
  131. _kf = dict(C99CodePrinter._kf, **_math_functions['C++17'])
  132. def _print_beta(self, expr):
  133. return self._print_math_func(expr)
  134. def _print_Ei(self, expr):
  135. return self._print_math_func(expr)
  136. def _print_zeta(self, expr):
  137. return self._print_math_func(expr)
  138. # _attach_print_methods(CXX17CodePrinter, _math_functions)
  139. cxx_code_printers = {
  140. 'c++98': CXX98CodePrinter,
  141. 'c++11': CXX11CodePrinter,
  142. 'c++17': CXX17CodePrinter
  143. }