aesaracode.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. from __future__ import annotations
  2. import math
  3. from typing import Any
  4. from sympy.external import import_module
  5. from sympy.printing.printer import Printer
  6. from sympy.utilities.exceptions import sympy_deprecation_warning
  7. from sympy.utilities.iterables import is_sequence
  8. import sympy
  9. from functools import partial
  10. aesara = import_module('aesara')
  11. if aesara:
  12. aes = aesara.scalar
  13. aet = aesara.tensor
  14. from aesara.tensor import nlinalg
  15. from aesara.tensor.elemwise import Elemwise
  16. from aesara.tensor.elemwise import DimShuffle
  17. # `true_divide` replaced `true_div` in Aesara 2.8.11 (released 2023) to
  18. # match NumPy
  19. # XXX: Remove this when not needed to support older versions.
  20. true_divide = getattr(aet, 'true_divide', None)
  21. if true_divide is None:
  22. true_divide = aet.true_div
  23. mapping = {
  24. sympy.Add: aet.add,
  25. sympy.Mul: aet.mul,
  26. sympy.Abs: aet.abs,
  27. sympy.sign: aet.sgn,
  28. sympy.ceiling: aet.ceil,
  29. sympy.floor: aet.floor,
  30. sympy.log: aet.log,
  31. sympy.exp: aet.exp,
  32. sympy.sqrt: aet.sqrt,
  33. sympy.cos: aet.cos,
  34. sympy.acos: aet.arccos,
  35. sympy.sin: aet.sin,
  36. sympy.asin: aet.arcsin,
  37. sympy.tan: aet.tan,
  38. sympy.atan: aet.arctan,
  39. sympy.atan2: aet.arctan2,
  40. sympy.cosh: aet.cosh,
  41. sympy.acosh: aet.arccosh,
  42. sympy.sinh: aet.sinh,
  43. sympy.asinh: aet.arcsinh,
  44. sympy.tanh: aet.tanh,
  45. sympy.atanh: aet.arctanh,
  46. sympy.re: aet.real,
  47. sympy.im: aet.imag,
  48. sympy.arg: aet.angle,
  49. sympy.erf: aet.erf,
  50. sympy.gamma: aet.gamma,
  51. sympy.loggamma: aet.gammaln,
  52. sympy.Pow: aet.pow,
  53. sympy.Eq: aet.eq,
  54. sympy.StrictGreaterThan: aet.gt,
  55. sympy.StrictLessThan: aet.lt,
  56. sympy.LessThan: aet.le,
  57. sympy.GreaterThan: aet.ge,
  58. sympy.And: aet.bitwise_and, # bitwise
  59. sympy.Or: aet.bitwise_or, # bitwise
  60. sympy.Not: aet.invert, # bitwise
  61. sympy.Xor: aet.bitwise_xor, # bitwise
  62. sympy.Max: aet.maximum, # Sympy accept >2 inputs, Aesara only 2
  63. sympy.Min: aet.minimum, # Sympy accept >2 inputs, Aesara only 2
  64. sympy.conjugate: aet.conj,
  65. sympy.core.numbers.ImaginaryUnit: lambda:aet.complex(0,1),
  66. # Matrices
  67. sympy.MatAdd: Elemwise(aes.add),
  68. sympy.HadamardProduct: Elemwise(aes.mul),
  69. sympy.Trace: nlinalg.trace,
  70. sympy.Determinant : nlinalg.det,
  71. sympy.Inverse: nlinalg.matrix_inverse,
  72. sympy.Transpose: DimShuffle((False, False), [1, 0]),
  73. }
  74. class AesaraPrinter(Printer):
  75. """
  76. .. deprecated:: 1.14.
  77. The ``Aesara Code printing`` is deprecated.See its documentation for
  78. more information. See :ref:`deprecated-aesaraprinter` for details.
  79. Code printer which creates Aesara symbolic expression graphs.
  80. Parameters
  81. ==========
  82. cache : dict
  83. Cache dictionary to use. If None (default) will use
  84. the global cache. To create a printer which does not depend on or alter
  85. global state pass an empty dictionary. Note: the dictionary is not
  86. copied on initialization of the printer and will be updated in-place,
  87. so using the same dict object when creating multiple printers or making
  88. multiple calls to :func:`.aesara_code` or :func:`.aesara_function` means
  89. the cache is shared between all these applications.
  90. Attributes
  91. ==========
  92. cache : dict
  93. A cache of Aesara variables which have been created for SymPy
  94. symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
  95. :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
  96. ensure that all references to a given symbol in an expression (or
  97. multiple expressions) are printed as the same Aesara variable, which is
  98. created only once. Symbols are differentiated only by name and type. The
  99. format of the cache's contents should be considered opaque to the user.
  100. """
  101. printmethod = "_aesara"
  102. def __init__(self, *args, **kwargs):
  103. self.cache = kwargs.pop('cache', {})
  104. super().__init__(*args, **kwargs)
  105. def _get_key(self, s, name=None, dtype=None, broadcastable=None):
  106. """ Get the cache key for a SymPy object.
  107. Parameters
  108. ==========
  109. s : sympy.core.basic.Basic
  110. SymPy object to get key for.
  111. name : str
  112. Name of object, if it does not have a ``name`` attribute.
  113. """
  114. if name is None:
  115. name = s.name
  116. return (name, type(s), s.args, dtype, broadcastable)
  117. def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
  118. """
  119. Get the Aesara variable for a SymPy symbol from the cache, or create it
  120. if it does not exist.
  121. """
  122. # Defaults
  123. if name is None:
  124. name = s.name
  125. if dtype is None:
  126. dtype = 'floatX'
  127. if broadcastable is None:
  128. broadcastable = ()
  129. key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)
  130. if key in self.cache:
  131. return self.cache[key]
  132. value = aet.tensor(name=name, dtype=dtype, shape=broadcastable)
  133. self.cache[key] = value
  134. return value
  135. def _print_Symbol(self, s, **kwargs):
  136. dtype = kwargs.get('dtypes', {}).get(s)
  137. bc = kwargs.get('broadcastables', {}).get(s)
  138. return self._get_or_create(s, dtype=dtype, broadcastable=bc)
  139. def _print_AppliedUndef(self, s, **kwargs):
  140. name = str(type(s)) + '_' + str(s.args[0])
  141. dtype = kwargs.get('dtypes', {}).get(s)
  142. bc = kwargs.get('broadcastables', {}).get(s)
  143. return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)
  144. def _print_Basic(self, expr, **kwargs):
  145. op = mapping[type(expr)]
  146. children = [self._print(arg, **kwargs) for arg in expr.args]
  147. return op(*children)
  148. def _print_Number(self, n, **kwargs):
  149. # Integers already taken care of below, interpret as float
  150. return float(n.evalf())
  151. def _print_MatrixSymbol(self, X, **kwargs):
  152. dtype = kwargs.get('dtypes', {}).get(X)
  153. return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))
  154. def _print_DenseMatrix(self, X, **kwargs):
  155. if not hasattr(aet, 'stacklists'):
  156. raise NotImplementedError(
  157. "Matrix translation not yet supported in this version of Aesara")
  158. return aet.stacklists([
  159. [self._print(arg, **kwargs) for arg in L]
  160. for L in X.tolist()
  161. ])
  162. _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix
  163. def _print_MatMul(self, expr, **kwargs):
  164. children = [self._print(arg, **kwargs) for arg in expr.args]
  165. result = children[0]
  166. for child in children[1:]:
  167. result = aet.dot(result, child)
  168. return result
  169. def _print_MatPow(self, expr, **kwargs):
  170. children = [self._print(arg, **kwargs) for arg in expr.args]
  171. result = 1
  172. if isinstance(children[1], int) and children[1] > 0:
  173. for i in range(children[1]):
  174. result = aet.dot(result, children[0])
  175. else:
  176. raise NotImplementedError('''Only non-negative integer
  177. powers of matrices can be handled by Aesara at the moment''')
  178. return result
  179. def _print_MatrixSlice(self, expr, **kwargs):
  180. parent = self._print(expr.parent, **kwargs)
  181. rowslice = self._print(slice(*expr.rowslice), **kwargs)
  182. colslice = self._print(slice(*expr.colslice), **kwargs)
  183. return parent[rowslice, colslice]
  184. def _print_BlockMatrix(self, expr, **kwargs):
  185. nrows, ncols = expr.blocks.shape
  186. blocks = [[self._print(expr.blocks[r, c], **kwargs)
  187. for c in range(ncols)]
  188. for r in range(nrows)]
  189. return aet.join(0, *[aet.join(1, *row) for row in blocks])
  190. def _print_slice(self, expr, **kwargs):
  191. return slice(*[self._print(i, **kwargs)
  192. if isinstance(i, sympy.Basic) else i
  193. for i in (expr.start, expr.stop, expr.step)])
  194. def _print_Pi(self, expr, **kwargs):
  195. return math.pi
  196. def _print_Piecewise(self, expr, **kwargs):
  197. import numpy as np
  198. e, cond = expr.args[0].args # First condition and corresponding value
  199. # Print conditional expression and value for first condition
  200. p_cond = self._print(cond, **kwargs)
  201. p_e = self._print(e, **kwargs)
  202. # One condition only
  203. if len(expr.args) == 1:
  204. # Return value if condition else NaN
  205. return aet.switch(p_cond, p_e, np.nan)
  206. # Return value_1 if condition_1 else evaluate remaining conditions
  207. p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
  208. return aet.switch(p_cond, p_e, p_remaining)
  209. def _print_Rational(self, expr, **kwargs):
  210. return true_divide(self._print(expr.p, **kwargs),
  211. self._print(expr.q, **kwargs))
  212. def _print_Integer(self, expr, **kwargs):
  213. return expr.p
  214. def _print_factorial(self, expr, **kwargs):
  215. return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)
  216. def _print_Derivative(self, deriv, **kwargs):
  217. from aesara.gradient import Rop
  218. rv = self._print(deriv.expr, **kwargs)
  219. for var in deriv.variables:
  220. var = self._print(var, **kwargs)
  221. rv = Rop(rv, var, aet.ones_like(var))
  222. return rv
  223. def emptyPrinter(self, expr):
  224. return expr
  225. def doprint(self, expr, dtypes=None, broadcastables=None):
  226. """ Convert a SymPy expression to a Aesara graph variable.
  227. The ``dtypes`` and ``broadcastables`` arguments are used to specify the
  228. data type, dimension, and broadcasting behavior of the Aesara variables
  229. corresponding to the free symbols in ``expr``. Each is a mapping from
  230. SymPy symbols to the value of the corresponding argument to
  231. ``aesara.tensor.var.TensorVariable``.
  232. See the corresponding `documentation page`__ for more information on
  233. broadcasting in Aesara.
  234. .. __: https://aesara.readthedocs.io/en/latest/reference/tensor/shapes.html#broadcasting
  235. Parameters
  236. ==========
  237. expr : sympy.core.expr.Expr
  238. SymPy expression to print.
  239. dtypes : dict
  240. Mapping from SymPy symbols to Aesara datatypes to use when creating
  241. new Aesara variables for those symbols. Corresponds to the ``dtype``
  242. argument to ``aesara.tensor.var.TensorVariable``. Defaults to ``'floatX'``
  243. for symbols not included in the mapping.
  244. broadcastables : dict
  245. Mapping from SymPy symbols to the value of the ``broadcastable``
  246. argument to ``aesara.tensor.var.TensorVariable`` to use when creating Aesara
  247. variables for those symbols. Defaults to the empty tuple for symbols
  248. not included in the mapping (resulting in a scalar).
  249. Returns
  250. =======
  251. aesara.graph.basic.Variable
  252. A variable corresponding to the expression's value in a Aesara
  253. symbolic expression graph.
  254. """
  255. if dtypes is None:
  256. dtypes = {}
  257. if broadcastables is None:
  258. broadcastables = {}
  259. return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
  260. global_cache: dict[Any, Any] = {}
  261. def aesara_code(expr, cache=None, **kwargs):
  262. """
  263. Convert a SymPy expression into a Aesara graph variable.
  264. Parameters
  265. ==========
  266. expr : sympy.core.expr.Expr
  267. SymPy expression object to convert.
  268. cache : dict
  269. Cached Aesara variables (see :class:`AesaraPrinter.cache
  270. <AesaraPrinter>`). Defaults to the module-level global cache.
  271. dtypes : dict
  272. Passed to :meth:`.AesaraPrinter.doprint`.
  273. broadcastables : dict
  274. Passed to :meth:`.AesaraPrinter.doprint`.
  275. Returns
  276. =======
  277. aesara.graph.basic.Variable
  278. A variable corresponding to the expression's value in a Aesara symbolic
  279. expression graph.
  280. """
  281. sympy_deprecation_warning(
  282. """
  283. The aesara_code function is deprecated.
  284. """,
  285. deprecated_since_version="1.14",
  286. active_deprecations_target='deprecated-aesaraprinter',
  287. )
  288. if not aesara:
  289. raise ImportError("aesara is required for aesara_code")
  290. if cache is None:
  291. cache = global_cache
  292. return AesaraPrinter(cache=cache, settings={}).doprint(expr, **kwargs)
  293. def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
  294. r"""
  295. Get value of ``broadcastables`` argument to :func:`.aesara_code` from
  296. keyword arguments to :func:`.aesara_function`.
  297. Included for backwards compatibility.
  298. Parameters
  299. ==========
  300. inputs
  301. Sequence of input symbols.
  302. dim : int
  303. Common number of dimensions for all inputs. Overrides other arguments
  304. if given.
  305. dims : dict
  306. Mapping from input symbols to number of dimensions. Overrides
  307. ``broadcastables`` argument if given.
  308. broadcastables : dict
  309. Explicit value of ``broadcastables`` argument to
  310. :meth:`.AesaraPrinter.doprint`. If not None function will return this value unchanged.
  311. Returns
  312. =======
  313. dict
  314. Dictionary mapping elements of ``inputs`` to their "broadcastable"
  315. values (tuple of ``bool``\ s).
  316. """
  317. if dim is not None:
  318. return dict.fromkeys(inputs, (False,) * dim)
  319. if dims is not None:
  320. maxdim = max(dims.values())
  321. return {
  322. s: (False,) * d + (True,) * (maxdim - d)
  323. for s, d in dims.items()
  324. }
  325. if broadcastables is not None:
  326. return broadcastables
  327. return {}
  328. def aesara_function(inputs, outputs, scalar=False, *,
  329. dim=None, dims=None, broadcastables=None, **kwargs):
  330. """
  331. Create a Aesara function from SymPy expressions.
  332. The inputs and outputs are converted to Aesara variables using
  333. :func:`.aesara_code` and then passed to ``aesara.function``.
  334. Parameters
  335. ==========
  336. inputs
  337. Sequence of symbols which constitute the inputs of the function.
  338. outputs
  339. Sequence of expressions which constitute the outputs(s) of the
  340. function. The free symbols of each expression must be a subset of
  341. ``inputs``.
  342. scalar : bool
  343. Convert 0-dimensional arrays in output to scalars. This will return a
  344. Python wrapper function around the Aesara function object.
  345. cache : dict
  346. Cached Aesara variables (see :class:`AesaraPrinter.cache
  347. <AesaraPrinter>`). Defaults to the module-level global cache.
  348. dtypes : dict
  349. Passed to :meth:`.AesaraPrinter.doprint`.
  350. broadcastables : dict
  351. Passed to :meth:`.AesaraPrinter.doprint`.
  352. dims : dict
  353. Alternative to ``broadcastables`` argument. Mapping from elements of
  354. ``inputs`` to integers indicating the dimension of their associated
  355. arrays/tensors. Overrides ``broadcastables`` argument if given.
  356. dim : int
  357. Another alternative to the ``broadcastables`` argument. Common number of
  358. dimensions to use for all arrays/tensors.
  359. ``aesara_function([x, y], [...], dim=2)`` is equivalent to using
  360. ``broadcastables={x: (False, False), y: (False, False)}``.
  361. Returns
  362. =======
  363. callable
  364. A callable object which takes values of ``inputs`` as positional
  365. arguments and returns an output array for each of the expressions
  366. in ``outputs``. If ``outputs`` is a single expression the function will
  367. return a Numpy array, if it is a list of multiple expressions the
  368. function will return a list of arrays. See description of the ``squeeze``
  369. argument above for the behavior when a single output is passed in a list.
  370. The returned object will either be an instance of
  371. ``aesara.compile.function.types.Function`` or a Python wrapper
  372. function around one. In both cases, the returned value will have a
  373. ``aesara_function`` attribute which points to the return value of
  374. ``aesara.function``.
  375. Examples
  376. ========
  377. >>> from sympy.abc import x, y, z
  378. >>> from sympy.printing.aesaracode import aesara_function
  379. A simple function with one input and one output:
  380. >>> f1 = aesara_function([x], [x**2 - 1], scalar=True)
  381. >>> f1(3)
  382. 8.0
  383. A function with multiple inputs and one output:
  384. >>> f2 = aesara_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
  385. >>> f2(3, 4, 2)
  386. 5.0
  387. A function with multiple inputs and multiple outputs:
  388. >>> f3 = aesara_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
  389. >>> f3(2, 3)
  390. [13.0, -5.0]
  391. See also
  392. ========
  393. dim_handling
  394. """
  395. sympy_deprecation_warning(
  396. """
  397. The aesara_function function is deprecated.
  398. """,
  399. deprecated_since_version="1.14",
  400. active_deprecations_target='deprecated-aesaraprinter',
  401. )
  402. if not aesara:
  403. raise ImportError("Aesara is required for aesara_function")
  404. # Pop off non-aesara keyword args
  405. cache = kwargs.pop('cache', {})
  406. dtypes = kwargs.pop('dtypes', {})
  407. broadcastables = dim_handling(
  408. inputs, dim=dim, dims=dims, broadcastables=broadcastables,
  409. )
  410. # Print inputs/outputs
  411. code = partial(aesara_code, cache=cache, dtypes=dtypes,
  412. broadcastables=broadcastables)
  413. tinputs = list(map(code, inputs))
  414. toutputs = list(map(code, outputs))
  415. #fix constant expressions as variables
  416. toutputs = [output if isinstance(output, aesara.graph.basic.Variable) else aet.as_tensor_variable(output) for output in toutputs]
  417. if len(toutputs) == 1:
  418. toutputs = toutputs[0]
  419. # Compile aesara func
  420. func = aesara.function(tinputs, toutputs, **kwargs)
  421. is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]
  422. # No wrapper required
  423. if not scalar or not any(is_0d):
  424. func.aesara_function = func
  425. return func
  426. # Create wrapper to convert 0-dimensional outputs to scalars
  427. def wrapper(*args):
  428. out = func(*args)
  429. # out can be array(1.0) or [array(1.0), array(2.0)]
  430. if is_sequence(out):
  431. return [o[()] if is_0d[i] else o for i, o in enumerate(out)]
  432. else:
  433. return out[()]
  434. wrapper.__wrapped__ = func
  435. wrapper.__doc__ = func.__doc__
  436. wrapper.aesara_function = func
  437. return wrapper