theanocode.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571
  1. """
  2. .. deprecated:: 1.8
  3. ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
  4. Aesara. Use ``sympy.printing.aesaracode`` instead. See
  5. :ref:`theanocode-deprecated` for more information.
  6. """
  7. from __future__ import annotations
  8. import math
  9. from typing import Any
  10. from sympy.external import import_module
  11. from sympy.printing.printer import Printer
  12. from sympy.utilities.iterables import is_sequence
  13. import sympy
  14. from functools import partial
  15. from sympy.utilities.decorator import doctest_depends_on
  16. from sympy.utilities.exceptions import sympy_deprecation_warning
# The ``theano_function`` doctests can only run when theano is installed.
__doctest_requires__ = {('theano_function',): ['theano']}

# ``import_module`` yields a falsy value when theano cannot be imported; the
# theano-dependent names below (ts, tt, tlinalg, mapping) are therefore only
# defined when the import succeeded, and the public entry points check
# ``if not theano`` before using them.
theano = import_module('theano')

if theano:
    ts = theano.scalar
    tt = theano.tensor
    from theano.sandbox import linalg as tlinalg

    # Lookup table used by ``TheanoPrinter._print_Basic``: the key is the
    # exact SymPy class (``type(expr)``) and the value is called with the
    # already-printed ``expr.args`` to build the corresponding Theano node.
    mapping = {
            sympy.Add: tt.add,
            sympy.Mul: tt.mul,
            sympy.Abs: tt.abs_,
            sympy.sign: tt.sgn,
            sympy.ceiling: tt.ceil,
            sympy.floor: tt.floor,
            sympy.log: tt.log,
            sympy.exp: tt.exp,
            sympy.sqrt: tt.sqrt,
            sympy.cos: tt.cos,
            sympy.acos: tt.arccos,
            sympy.sin: tt.sin,
            sympy.asin: tt.arcsin,
            sympy.tan: tt.tan,
            sympy.atan: tt.arctan,
            sympy.atan2: tt.arctan2,
            sympy.cosh: tt.cosh,
            sympy.acosh: tt.arccosh,
            sympy.sinh: tt.sinh,
            sympy.asinh: tt.arcsinh,
            sympy.tanh: tt.tanh,
            sympy.atanh: tt.arctanh,
            sympy.re: tt.real,
            sympy.im: tt.imag,
            sympy.arg: tt.angle,
            sympy.erf: tt.erf,
            sympy.gamma: tt.gamma,
            sympy.loggamma: tt.gammaln,
            sympy.Pow: tt.pow,
            sympy.Eq: tt.eq,
            sympy.StrictGreaterThan: tt.gt,
            sympy.StrictLessThan: tt.lt,
            sympy.LessThan: tt.le,
            sympy.GreaterThan: tt.ge,
            sympy.And: tt.and_,
            sympy.Or: tt.or_,
            sympy.Max: tt.maximum,  # SymPy accepts >2 inputs, Theano only 2
            sympy.Min: tt.minimum,  # SymPy accepts >2 inputs, Theano only 2
            sympy.conjugate: tt.conj,
            # ImaginaryUnit has no args, so the factory is called with none.
            sympy.core.numbers.ImaginaryUnit: lambda:tt.complex(0,1),
            # Matrices
            sympy.MatAdd: tt.Elemwise(ts.add),
            sympy.HadamardProduct: tt.Elemwise(ts.mul),
            sympy.Trace: tlinalg.trace,
            sympy.Determinant : tlinalg.det,
            sympy.Inverse: tlinalg.matrix_inverse,
            # Swap the two (non-broadcastable) axes of a 2-d tensor.
            sympy.Transpose: tt.DimShuffle((False, False), [1, 0]),
    }
  72. class TheanoPrinter(Printer):
  73. """ Code printer which creates Theano symbolic expression graphs.
  74. Parameters
  75. ==========
  76. cache : dict
  77. Cache dictionary to use. If None (default) will use
  78. the global cache. To create a printer which does not depend on or alter
  79. global state pass an empty dictionary. Note: the dictionary is not
  80. copied on initialization of the printer and will be updated in-place,
  81. so using the same dict object when creating multiple printers or making
  82. multiple calls to :func:`.theano_code` or :func:`.theano_function` means
  83. the cache is shared between all these applications.
  84. Attributes
  85. ==========
  86. cache : dict
  87. A cache of Theano variables which have been created for SymPy
  88. symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
  89. :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
  90. ensure that all references to a given symbol in an expression (or
  91. multiple expressions) are printed as the same Theano variable, which is
  92. created only once. Symbols are differentiated only by name and type. The
  93. format of the cache's contents should be considered opaque to the user.
  94. """
  95. printmethod = "_theano"
  96. def __init__(self, *args, **kwargs):
  97. self.cache = kwargs.pop('cache', {})
  98. super().__init__(*args, **kwargs)
  99. def _get_key(self, s, name=None, dtype=None, broadcastable=None):
  100. """ Get the cache key for a SymPy object.
  101. Parameters
  102. ==========
  103. s : sympy.core.basic.Basic
  104. SymPy object to get key for.
  105. name : str
  106. Name of object, if it does not have a ``name`` attribute.
  107. """
  108. if name is None:
  109. name = s.name
  110. return (name, type(s), s.args, dtype, broadcastable)
  111. def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
  112. """
  113. Get the Theano variable for a SymPy symbol from the cache, or create it
  114. if it does not exist.
  115. """
  116. # Defaults
  117. if name is None:
  118. name = s.name
  119. if dtype is None:
  120. dtype = 'floatX'
  121. if broadcastable is None:
  122. broadcastable = ()
  123. key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)
  124. if key in self.cache:
  125. return self.cache[key]
  126. value = tt.tensor(name=name, dtype=dtype, broadcastable=broadcastable)
  127. self.cache[key] = value
  128. return value
  129. def _print_Symbol(self, s, **kwargs):
  130. dtype = kwargs.get('dtypes', {}).get(s)
  131. bc = kwargs.get('broadcastables', {}).get(s)
  132. return self._get_or_create(s, dtype=dtype, broadcastable=bc)
  133. def _print_AppliedUndef(self, s, **kwargs):
  134. name = str(type(s)) + '_' + str(s.args[0])
  135. dtype = kwargs.get('dtypes', {}).get(s)
  136. bc = kwargs.get('broadcastables', {}).get(s)
  137. return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)
  138. def _print_Basic(self, expr, **kwargs):
  139. op = mapping[type(expr)]
  140. children = [self._print(arg, **kwargs) for arg in expr.args]
  141. return op(*children)
  142. def _print_Number(self, n, **kwargs):
  143. # Integers already taken care of below, interpret as float
  144. return float(n.evalf())
  145. def _print_MatrixSymbol(self, X, **kwargs):
  146. dtype = kwargs.get('dtypes', {}).get(X)
  147. return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))
  148. def _print_DenseMatrix(self, X, **kwargs):
  149. if not hasattr(tt, 'stacklists'):
  150. raise NotImplementedError(
  151. "Matrix translation not yet supported in this version of Theano")
  152. return tt.stacklists([
  153. [self._print(arg, **kwargs) for arg in L]
  154. for L in X.tolist()
  155. ])
  156. _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix
  157. def _print_MatMul(self, expr, **kwargs):
  158. children = [self._print(arg, **kwargs) for arg in expr.args]
  159. result = children[0]
  160. for child in children[1:]:
  161. result = tt.dot(result, child)
  162. return result
  163. def _print_MatPow(self, expr, **kwargs):
  164. children = [self._print(arg, **kwargs) for arg in expr.args]
  165. result = 1
  166. if isinstance(children[1], int) and children[1] > 0:
  167. for i in range(children[1]):
  168. result = tt.dot(result, children[0])
  169. else:
  170. raise NotImplementedError('''Only non-negative integer
  171. powers of matrices can be handled by Theano at the moment''')
  172. return result
  173. def _print_MatrixSlice(self, expr, **kwargs):
  174. parent = self._print(expr.parent, **kwargs)
  175. rowslice = self._print(slice(*expr.rowslice), **kwargs)
  176. colslice = self._print(slice(*expr.colslice), **kwargs)
  177. return parent[rowslice, colslice]
  178. def _print_BlockMatrix(self, expr, **kwargs):
  179. nrows, ncols = expr.blocks.shape
  180. blocks = [[self._print(expr.blocks[r, c], **kwargs)
  181. for c in range(ncols)]
  182. for r in range(nrows)]
  183. return tt.join(0, *[tt.join(1, *row) for row in blocks])
  184. def _print_slice(self, expr, **kwargs):
  185. return slice(*[self._print(i, **kwargs)
  186. if isinstance(i, sympy.Basic) else i
  187. for i in (expr.start, expr.stop, expr.step)])
  188. def _print_Pi(self, expr, **kwargs):
  189. return math.pi
  190. def _print_Exp1(self, expr, **kwargs):
  191. return ts.exp(1)
  192. def _print_Piecewise(self, expr, **kwargs):
  193. import numpy as np
  194. e, cond = expr.args[0].args # First condition and corresponding value
  195. # Print conditional expression and value for first condition
  196. p_cond = self._print(cond, **kwargs)
  197. p_e = self._print(e, **kwargs)
  198. # One condition only
  199. if len(expr.args) == 1:
  200. # Return value if condition else NaN
  201. return tt.switch(p_cond, p_e, np.nan)
  202. # Return value_1 if condition_1 else evaluate remaining conditions
  203. p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
  204. return tt.switch(p_cond, p_e, p_remaining)
  205. def _print_Rational(self, expr, **kwargs):
  206. return tt.true_div(self._print(expr.p, **kwargs),
  207. self._print(expr.q, **kwargs))
  208. def _print_Integer(self, expr, **kwargs):
  209. return expr.p
  210. def _print_factorial(self, expr, **kwargs):
  211. return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)
  212. def _print_Derivative(self, deriv, **kwargs):
  213. rv = self._print(deriv.expr, **kwargs)
  214. for var in deriv.variables:
  215. var = self._print(var, **kwargs)
  216. rv = tt.Rop(rv, var, tt.ones_like(var))
  217. return rv
  218. def emptyPrinter(self, expr):
  219. return expr
  220. def doprint(self, expr, dtypes=None, broadcastables=None):
  221. """ Convert a SymPy expression to a Theano graph variable.
  222. The ``dtypes`` and ``broadcastables`` arguments are used to specify the
  223. data type, dimension, and broadcasting behavior of the Theano variables
  224. corresponding to the free symbols in ``expr``. Each is a mapping from
  225. SymPy symbols to the value of the corresponding argument to
  226. ``theano.tensor.Tensor``.
  227. See the corresponding `documentation page`__ for more information on
  228. broadcasting in Theano.
  229. .. __: http://deeplearning.net/software/theano/tutorial/broadcasting.html
  230. Parameters
  231. ==========
  232. expr : sympy.core.expr.Expr
  233. SymPy expression to print.
  234. dtypes : dict
  235. Mapping from SymPy symbols to Theano datatypes to use when creating
  236. new Theano variables for those symbols. Corresponds to the ``dtype``
  237. argument to ``theano.tensor.Tensor``. Defaults to ``'floatX'``
  238. for symbols not included in the mapping.
  239. broadcastables : dict
  240. Mapping from SymPy symbols to the value of the ``broadcastable``
  241. argument to ``theano.tensor.Tensor`` to use when creating Theano
  242. variables for those symbols. Defaults to the empty tuple for symbols
  243. not included in the mapping (resulting in a scalar).
  244. Returns
  245. =======
  246. theano.gof.graph.Variable
  247. A variable corresponding to the expression's value in a Theano
  248. symbolic expression graph.
  249. """
  250. if dtypes is None:
  251. dtypes = {}
  252. if broadcastables is None:
  253. broadcastables = {}
  254. return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
  255. global_cache: dict[Any, Any] = {}
  256. def theano_code(expr, cache=None, **kwargs):
  257. """
  258. Convert a SymPy expression into a Theano graph variable.
  259. .. deprecated:: 1.8
  260. ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
  261. Aesara. Use ``sympy.printing.aesaracode`` instead. See
  262. :ref:`theanocode-deprecated` for more information.
  263. Parameters
  264. ==========
  265. expr : sympy.core.expr.Expr
  266. SymPy expression object to convert.
  267. cache : dict
  268. Cached Theano variables (see :class:`TheanoPrinter.cache
  269. <TheanoPrinter>`). Defaults to the module-level global cache.
  270. dtypes : dict
  271. Passed to :meth:`.TheanoPrinter.doprint`.
  272. broadcastables : dict
  273. Passed to :meth:`.TheanoPrinter.doprint`.
  274. Returns
  275. =======
  276. theano.gof.graph.Variable
  277. A variable corresponding to the expression's value in a Theano symbolic
  278. expression graph.
  279. """
  280. sympy_deprecation_warning(
  281. """
  282. sympy.printing.theanocode is deprecated. Theano has been renamed to
  283. Aesara. Use sympy.printing.aesaracode instead.""",
  284. deprecated_since_version="1.8",
  285. active_deprecations_target='theanocode-deprecated')
  286. if not theano:
  287. raise ImportError("theano is required for theano_code")
  288. if cache is None:
  289. cache = global_cache
  290. return TheanoPrinter(cache=cache, settings={}).doprint(expr, **kwargs)
  291. def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
  292. r"""
  293. Get value of ``broadcastables`` argument to :func:`.theano_code` from
  294. keyword arguments to :func:`.theano_function`.
  295. Included for backwards compatibility.
  296. Parameters
  297. ==========
  298. inputs
  299. Sequence of input symbols.
  300. dim : int
  301. Common number of dimensions for all inputs. Overrides other arguments
  302. if given.
  303. dims : dict
  304. Mapping from input symbols to number of dimensions. Overrides
  305. ``broadcastables`` argument if given.
  306. broadcastables : dict
  307. Explicit value of ``broadcastables`` argument to
  308. :meth:`.TheanoPrinter.doprint`. If not None function will return this value unchanged.
  309. Returns
  310. =======
  311. dict
  312. Dictionary mapping elements of ``inputs`` to their "broadcastable"
  313. values (tuple of ``bool``\ s).
  314. """
  315. if dim is not None:
  316. return dict.fromkeys(inputs, (False,) * dim)
  317. if dims is not None:
  318. maxdim = max(dims.values())
  319. return {
  320. s: (False,) * d + (True,) * (maxdim - d)
  321. for s, d in dims.items()
  322. }
  323. if broadcastables is not None:
  324. return broadcastables
  325. return {}
  326. @doctest_depends_on(modules=('theano',))
  327. def theano_function(inputs, outputs, scalar=False, *,
  328. dim=None, dims=None, broadcastables=None, **kwargs):
  329. """
  330. Create a Theano function from SymPy expressions.
  331. .. deprecated:: 1.8
  332. ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
  333. Aesara. Use ``sympy.printing.aesaracode`` instead. See
  334. :ref:`theanocode-deprecated` for more information.
  335. The inputs and outputs are converted to Theano variables using
  336. :func:`.theano_code` and then passed to ``theano.function``.
  337. Parameters
  338. ==========
  339. inputs
  340. Sequence of symbols which constitute the inputs of the function.
  341. outputs
  342. Sequence of expressions which constitute the outputs(s) of the
  343. function. The free symbols of each expression must be a subset of
  344. ``inputs``.
  345. scalar : bool
  346. Convert 0-dimensional arrays in output to scalars. This will return a
  347. Python wrapper function around the Theano function object.
  348. cache : dict
  349. Cached Theano variables (see :class:`TheanoPrinter.cache
  350. <TheanoPrinter>`). Defaults to the module-level global cache.
  351. dtypes : dict
  352. Passed to :meth:`.TheanoPrinter.doprint`.
  353. broadcastables : dict
  354. Passed to :meth:`.TheanoPrinter.doprint`.
  355. dims : dict
  356. Alternative to ``broadcastables`` argument. Mapping from elements of
  357. ``inputs`` to integers indicating the dimension of their associated
  358. arrays/tensors. Overrides ``broadcastables`` argument if given.
  359. dim : int
  360. Another alternative to the ``broadcastables`` argument. Common number of
  361. dimensions to use for all arrays/tensors.
  362. ``theano_function([x, y], [...], dim=2)`` is equivalent to using
  363. ``broadcastables={x: (False, False), y: (False, False)}``.
  364. Returns
  365. =======
  366. callable
  367. A callable object which takes values of ``inputs`` as positional
  368. arguments and returns an output array for each of the expressions
  369. in ``outputs``. If ``outputs`` is a single expression the function will
  370. return a Numpy array, if it is a list of multiple expressions the
  371. function will return a list of arrays. See description of the ``squeeze``
  372. argument above for the behavior when a single output is passed in a list.
  373. The returned object will either be an instance of
  374. ``theano.compile.function_module.Function`` or a Python wrapper
  375. function around one. In both cases, the returned value will have a
  376. ``theano_function`` attribute which points to the return value of
  377. ``theano.function``.
  378. Examples
  379. ========
  380. >>> from sympy.abc import x, y, z
  381. >>> from sympy.printing.theanocode import theano_function
  382. A simple function with one input and one output:
  383. >>> f1 = theano_function([x], [x**2 - 1], scalar=True)
  384. >>> f1(3)
  385. 8.0
  386. A function with multiple inputs and one output:
  387. >>> f2 = theano_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
  388. >>> f2(3, 4, 2)
  389. 5.0
  390. A function with multiple inputs and multiple outputs:
  391. >>> f3 = theano_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
  392. >>> f3(2, 3)
  393. [13.0, -5.0]
  394. See also
  395. ========
  396. dim_handling
  397. """
  398. sympy_deprecation_warning(
  399. """
  400. sympy.printing.theanocode is deprecated. Theano has been renamed to Aesara. Use sympy.printing.aesaracode instead""",
  401. deprecated_since_version="1.8",
  402. active_deprecations_target='theanocode-deprecated')
  403. if not theano:
  404. raise ImportError("theano is required for theano_function")
  405. # Pop off non-theano keyword args
  406. cache = kwargs.pop('cache', {})
  407. dtypes = kwargs.pop('dtypes', {})
  408. broadcastables = dim_handling(
  409. inputs, dim=dim, dims=dims, broadcastables=broadcastables,
  410. )
  411. # Print inputs/outputs
  412. code = partial(theano_code, cache=cache, dtypes=dtypes,
  413. broadcastables=broadcastables)
  414. tinputs = list(map(code, inputs))
  415. toutputs = list(map(code, outputs))
  416. #fix constant expressions as variables
  417. toutputs = [output if isinstance(output, theano.Variable) else tt.as_tensor_variable(output) for output in toutputs]
  418. if len(toutputs) == 1:
  419. toutputs = toutputs[0]
  420. # Compile theano func
  421. func = theano.function(tinputs, toutputs, **kwargs)
  422. is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]
  423. # No wrapper required
  424. if not scalar or not any(is_0d):
  425. func.theano_function = func
  426. return func
  427. # Create wrapper to convert 0-dimensional outputs to scalars
  428. def wrapper(*args):
  429. out = func(*args)
  430. # out can be array(1.0) or [array(1.0), array(2.0)]
  431. if is_sequence(out):
  432. return [o[()] if is_0d[i] else o for i, o in enumerate(out)]
  433. else:
  434. return out[()]
  435. wrapper.__wrapped__ = func
  436. wrapper.__doc__ = func.__doc__
  437. wrapper.theano_function = func
  438. return wrapper