# tensorflow.py (8.0 KB)
  1. import sympy.codegen
  2. import sympy.codegen.cfunctions
  3. from sympy.external.importtools import version_tuple
  4. from collections.abc import Iterable
  5. from sympy.core.mul import Mul
  6. from sympy.core.singleton import S
  7. from sympy.codegen.cfunctions import Sqrt
  8. from sympy.external import import_module
  9. from sympy.printing.precedence import PRECEDENCE
  10. from sympy.printing.pycode import AbstractPythonCodePrinter, ArrayPrinter
  11. import sympy
  12. tensorflow = import_module('tensorflow')
  13. class TensorflowPrinter(ArrayPrinter, AbstractPythonCodePrinter):
  14. """
  15. Tensorflow printer which handles vectorized piecewise functions,
  16. logical operators, max/min, and relational operators.
  17. """
  18. printmethod = "_tensorflowcode"
  19. mapping = {
  20. sympy.Abs: "tensorflow.math.abs",
  21. sympy.sign: "tensorflow.math.sign",
  22. # XXX May raise error for ints.
  23. sympy.ceiling: "tensorflow.math.ceil",
  24. sympy.floor: "tensorflow.math.floor",
  25. sympy.log: "tensorflow.math.log",
  26. sympy.exp: "tensorflow.math.exp",
  27. Sqrt: "tensorflow.math.sqrt",
  28. sympy.cos: "tensorflow.math.cos",
  29. sympy.acos: "tensorflow.math.acos",
  30. sympy.sin: "tensorflow.math.sin",
  31. sympy.asin: "tensorflow.math.asin",
  32. sympy.tan: "tensorflow.math.tan",
  33. sympy.atan: "tensorflow.math.atan",
  34. sympy.atan2: "tensorflow.math.atan2",
  35. # XXX Also may give NaN for complex results.
  36. sympy.cosh: "tensorflow.math.cosh",
  37. sympy.acosh: "tensorflow.math.acosh",
  38. sympy.sinh: "tensorflow.math.sinh",
  39. sympy.asinh: "tensorflow.math.asinh",
  40. sympy.tanh: "tensorflow.math.tanh",
  41. sympy.atanh: "tensorflow.math.atanh",
  42. sympy.re: "tensorflow.math.real",
  43. sympy.im: "tensorflow.math.imag",
  44. sympy.arg: "tensorflow.math.angle",
  45. # XXX May raise error for ints and complexes
  46. sympy.erf: "tensorflow.math.erf",
  47. sympy.loggamma: "tensorflow.math.lgamma",
  48. sympy.Eq: "tensorflow.math.equal",
  49. sympy.Ne: "tensorflow.math.not_equal",
  50. sympy.StrictGreaterThan: "tensorflow.math.greater",
  51. sympy.StrictLessThan: "tensorflow.math.less",
  52. sympy.LessThan: "tensorflow.math.less_equal",
  53. sympy.GreaterThan: "tensorflow.math.greater_equal",
  54. sympy.And: "tensorflow.math.logical_and",
  55. sympy.Or: "tensorflow.math.logical_or",
  56. sympy.Not: "tensorflow.math.logical_not",
  57. sympy.Max: "tensorflow.math.maximum",
  58. sympy.Min: "tensorflow.math.minimum",
  59. # Matrices
  60. sympy.MatAdd: "tensorflow.math.add",
  61. sympy.HadamardProduct: "tensorflow.math.multiply",
  62. sympy.Trace: "tensorflow.linalg.trace",
  63. # XXX May raise error for integer matrices.
  64. sympy.Determinant : "tensorflow.linalg.det",
  65. }
  66. _default_settings = dict(
  67. AbstractPythonCodePrinter._default_settings,
  68. tensorflow_version=None
  69. )
  70. def __init__(self, settings=None):
  71. super().__init__(settings)
  72. version = self._settings['tensorflow_version']
  73. if version is None and tensorflow:
  74. version = tensorflow.__version__
  75. self.tensorflow_version = version
  76. def _print_Function(self, expr):
  77. op = self.mapping.get(type(expr), None)
  78. if op is None:
  79. return super()._print_Basic(expr)
  80. children = [self._print(arg) for arg in expr.args]
  81. if len(children) == 1:
  82. return "%s(%s)" % (
  83. self._module_format(op),
  84. children[0]
  85. )
  86. else:
  87. return self._expand_fold_binary_op(op, children)
  88. _print_Expr = _print_Function
  89. _print_Application = _print_Function
  90. _print_MatrixExpr = _print_Function
  91. # TODO: a better class structure would avoid this mess:
  92. _print_Relational = _print_Function
  93. _print_Not = _print_Function
  94. _print_And = _print_Function
  95. _print_Or = _print_Function
  96. _print_HadamardProduct = _print_Function
  97. _print_Trace = _print_Function
  98. _print_Determinant = _print_Function
  99. def _print_Inverse(self, expr):
  100. op = self._module_format('tensorflow.linalg.inv')
  101. return "{}({})".format(op, self._print(expr.arg))
  102. def _print_Transpose(self, expr):
  103. version = self.tensorflow_version
  104. if version and version_tuple(version) < version_tuple('1.14'):
  105. op = self._module_format('tensorflow.matrix_transpose')
  106. else:
  107. op = self._module_format('tensorflow.linalg.matrix_transpose')
  108. return "{}({})".format(op, self._print(expr.arg))
  109. def _print_Derivative(self, expr):
  110. variables = expr.variables
  111. if any(isinstance(i, Iterable) for i in variables):
  112. raise NotImplementedError("derivation by multiple variables is not supported")
  113. def unfold(expr, args):
  114. if not args:
  115. return self._print(expr)
  116. return "%s(%s, %s)[0]" % (
  117. self._module_format("tensorflow.gradients"),
  118. unfold(expr, args[:-1]),
  119. self._print(args[-1]),
  120. )
  121. return unfold(expr.expr, variables)
  122. def _print_Piecewise(self, expr):
  123. version = self.tensorflow_version
  124. if version and version_tuple(version) < version_tuple('1.0'):
  125. tensorflow_piecewise = "tensorflow.select"
  126. else:
  127. tensorflow_piecewise = "tensorflow.where"
  128. from sympy.functions.elementary.piecewise import Piecewise
  129. e, cond = expr.args[0].args
  130. if len(expr.args) == 1:
  131. return '{}({}, {}, {})'.format(
  132. self._module_format(tensorflow_piecewise),
  133. self._print(cond),
  134. self._print(e),
  135. 0)
  136. return '{}({}, {}, {})'.format(
  137. self._module_format(tensorflow_piecewise),
  138. self._print(cond),
  139. self._print(e),
  140. self._print(Piecewise(*expr.args[1:])))
  141. def _print_Pow(self, expr):
  142. # XXX May raise error for
  143. # int**float or int**complex or float**complex
  144. base, exp = expr.args
  145. if expr.exp == S.Half:
  146. return "{}({})".format(
  147. self._module_format("tensorflow.math.sqrt"), self._print(base))
  148. return "{}({}, {})".format(
  149. self._module_format("tensorflow.math.pow"),
  150. self._print(base), self._print(exp))
  151. def _print_MatrixBase(self, expr):
  152. tensorflow_f = "tensorflow.Variable" if expr.free_symbols else "tensorflow.constant"
  153. data = "["+", ".join(["["+", ".join([self._print(j) for j in i])+"]" for i in expr.tolist()])+"]"
  154. return "%s(%s)" % (
  155. self._module_format(tensorflow_f),
  156. data,
  157. )
  158. def _print_MatMul(self, expr):
  159. from sympy.matrices.expressions import MatrixExpr
  160. mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)]
  161. args = [arg for arg in expr.args if arg not in mat_args]
  162. if args:
  163. return "%s*%s" % (
  164. self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]),
  165. self._expand_fold_binary_op(
  166. "tensorflow.linalg.matmul", mat_args)
  167. )
  168. else:
  169. return self._expand_fold_binary_op(
  170. "tensorflow.linalg.matmul", mat_args)
  171. def _print_MatPow(self, expr):
  172. return self._expand_fold_binary_op(
  173. "tensorflow.linalg.matmul", [expr.base]*expr.exp)
  174. def _print_CodeBlock(self, expr):
  175. # TODO: is this necessary?
  176. ret = []
  177. for subexpr in expr.args:
  178. ret.append(self._print(subexpr))
  179. return "\n".join(ret)
  180. def _print_isnan(self, exp):
  181. return f'tensorflow.math.is_nan({self._print(*exp.args)})'
  182. def _print_isinf(self, exp):
  183. return f'tensorflow.math.is_inf({self._print(*exp.args)})'
  184. _module = "tensorflow"
  185. _einsum = "linalg.einsum"
  186. _add = "math.add"
  187. _transpose = "transpose"
  188. _ones = "ones"
  189. _zeros = "zeros"
  190. def tensorflow_code(expr, **settings):
  191. printer = TensorflowPrinter(settings)
  192. return printer.doprint(expr)