numpy.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541
  1. from sympy.core import S
  2. from sympy.core.function import Lambda
  3. from sympy.core.power import Pow
  4. from .pycode import PythonCodePrinter, _known_functions_math, _print_known_const, _print_known_func, _unpack_integral_limits, ArrayPrinter
  5. from .codeprinter import CodePrinter
  # Keys of ``_known_functions_math`` that are left out of the NumPy table.
  _not_in_numpy = ['erf', 'erfc', 'factorial', 'gamma', 'loggamma']

  # (sympy name, target name) pairs carried over from ``_known_functions_math``.
  _in_numpy = [
      (name, target)
      for name, target in _known_functions_math.items()
      if name not in _not_in_numpy
  ]

  # Carried-over pairs first, then the NumPy-specific spellings, which win on
  # key collisions.
  _known_functions_numpy = {
      **dict(_in_numpy),
      'acos': 'arccos',
      'acosh': 'arccosh',
      'asin': 'arcsin',
      'asinh': 'arcsinh',
      'atan': 'arctan',
      'atan2': 'arctan2',
      'atanh': 'arctanh',
      'exp2': 'exp2',
      'sign': 'sign',
      'logaddexp': 'logaddexp',
      'logaddexp2': 'logaddexp2',
      'isinf': 'isinf',
      'isnan': 'isnan',
  }

  # SymPy constant class name -> attribute name inside the array module.
  _known_constants_numpy = dict(
      Exp1='e',
      Pi='pi',
      EulerGamma='euler_gamma',
      NaN='nan',
      Infinity='inf',
  )

  # Same tables with the ``numpy.`` prefix applied to every target.
  _numpy_known_functions = {
      name: f'numpy.{target}' for name, target in _known_functions_numpy.items()
  }
  _numpy_known_constants = {
      name: f'numpy.{target}' for name, target in _known_constants_numpy.items()
  }
  class NumPyPrinter(ArrayPrinter, PythonCodePrinter):
      """
      Numpy printer which handles vectorized piecewise functions,
      logical operators, etc.

      Emitted callable names are built from the class attribute ``_module``
      and passed through ``self._module_format``.
      """

      _module = 'numpy'
      _kf = _numpy_known_functions
      _kc = _numpy_known_constants

      def __init__(self, settings=None):
          """
          `settings` is forwarded unchanged to the parent ``__init__``.

          The array module is not a constructor argument: it is read from the
          class attribute ``_module`` ('numpy' for this class).
          """
          self.language = "Python with {}".format(self._module)
          self.printmethod = "_{}code".format(self._module)

          # Instance-level table: PythonCodePrinter's entries, overridden by
          # this class's own ``_kf`` entries on key collisions.
          self._kf = {**PythonCodePrinter._kf, **self._kf}

          super().__init__(settings=settings)

      def _print_seq(self, seq):
          "General sequence printer: converts to tuple"
          # Print tuples here instead of lists because numba supports
          # tuples in nopython mode.
          items = [self._print(item) for item in seq]
          if not items:
              # '(,)' is a syntax error; an empty sequence is the empty tuple.
              return '()'
          return '({},)'.format(', '.join(items))

      def _print_NegativeInfinity(self, expr):
          """Print negative infinity as '-' followed by the printed S.Infinity."""
          return '-' + self._print(S.Infinity)

      def _print_MatMul(self, expr):
          "Matrix multiplication printer"
          # Query the decomposition once; index 0 is the coefficient and
          # index 1 the list of matrix factors.
          coeff_and_matrices = expr.as_coeff_matrices()
          coeff = coeff_and_matrices[0]
          if coeff is not S.One:
              # The coefficient is appended as the last operand of the chain.
              expr_list = coeff_and_matrices[1] + [coeff]
              return '({})'.format(').dot('.join(self._print(i) for i in expr_list))
          return '({})'.format(').dot('.join(self._print(i) for i in expr.args))

      def _print_MatPow(self, expr):
          "Matrix power printer"
          return '{}({}, {})'.format(
              self._module_format(self._module + '.linalg.matrix_power'),
              self._print(expr.args[0]), self._print(expr.args[1]))

      def _print_Inverse(self, expr):
          "Matrix inverse printer"
          return '{}({})'.format(
              self._module_format(self._module + '.linalg.inv'),
              self._print(expr.args[0]))

      def _print_DotProduct(self, expr):
          """Print a DotProduct as ``<module>.dot(arg1, arg2)``."""
          # DotProduct allows any shape order, but numpy.dot does matrix
          # multiplication, so we have to make sure it gets 1 x n by n x 1.
          arg1, arg2 = expr.args
          if arg1.shape[0] != 1:
              arg1 = arg1.T
          if arg2.shape[1] != 1:
              arg2 = arg2.T

          return "%s(%s, %s)" % (self._module_format(self._module + '.dot'),
                                 self._print(arg1),
                                 self._print(arg2))

      def _print_MatrixSolve(self, expr):
          """Print as ``<module>.linalg.solve(matrix, vector)``."""
          return "%s(%s, %s)" % (self._module_format(self._module + '.linalg.solve'),
                                 self._print(expr.matrix),
                                 self._print(expr.vector))

      def _print_ZeroMatrix(self, expr):
          """Print as ``<module>.zeros(shape)``."""
          return '{}({})'.format(self._module_format(self._module + '.zeros'),
                                 self._print(expr.shape))

      def _print_OneMatrix(self, expr):
          """Print as ``<module>.ones(shape)``."""
          return '{}({})'.format(self._module_format(self._module + '.ones'),
                                 self._print(expr.shape))

      def _print_FunctionMatrix(self, expr):
          """Print as ``<module>.fromfunction(lambda ...: body, shape)``."""
          from sympy.abc import i, j
          lamda = expr.lamda
          if not isinstance(lamda, Lambda):
              # Wrap a non-Lambda callable so its variables and body can be
              # read from ``args``.
              lamda = Lambda((i, j), lamda(i, j))
          return '{}(lambda {}: {}, {})'.format(
              self._module_format(self._module + '.fromfunction'),
              ', '.join(self._print(arg) for arg in lamda.args[0]),
              self._print(lamda.args[1]), self._print(expr.shape))

      def _helper_nested_binary_call(self, func, args):
          """Fold ``args`` into right-nested two-argument calls of ``func``.

          ``(a, b, c)`` becomes ``func(a, func(b, c))``; a single argument is
          printed on its own.
          """
          opening = ''.join('{}({}, '.format(func, self._print(arg))
                            for arg in args[:-1])
          return opening + "{}{}".format(self._print(args[-1]),
                                         ')' * (len(args) - 1))

      def _print_HadamardProduct(self, expr):
          """Print as nested ``<module>.multiply`` calls."""
          func = self._module_format(self._module + '.multiply')
          return self._helper_nested_binary_call(func, expr.args)

      def _print_KroneckerProduct(self, expr):
          """Print as nested ``<module>.kron`` calls."""
          func = self._module_format(self._module + '.kron')
          return self._helper_nested_binary_call(func, expr.args)

      def _print_Adjoint(self, expr):
          """Print as ``<module>.conjugate(<module>.transpose(arg))``."""
          return '{}({}({}))'.format(
              self._module_format(self._module + '.conjugate'),
              self._module_format(self._module + '.transpose'),
              self._print(expr.args[0]))

      def _print_DiagonalOf(self, expr):
          """Print as ``<module>.reshape(<module>.diag(arg), (-1, 1))``."""
          vect = '{}({})'.format(
              self._module_format(self._module + '.diag'),
              self._print(expr.arg))
          return '{}({}, (-1, 1))'.format(
              self._module_format(self._module + '.reshape'), vect)

      def _print_DiagMatrix(self, expr):
          """Print as ``<module>.diagflat(arg)``."""
          return '{}({})'.format(self._module_format(self._module + '.diagflat'),
                                 self._print(expr.args[0]))

      def _print_DiagonalMatrix(self, expr):
          """Print as ``<module>.multiply(arg, <module>.eye(rows, cols))``."""
          return '{}({}, {}({}, {}))'.format(
              self._module_format(self._module + '.multiply'),
              self._print(expr.arg), self._module_format(self._module + '.eye'),
              self._print(expr.shape[0]), self._print(expr.shape[1]))

      def _print_Piecewise(self, expr):
          "Piecewise function printer"
          from sympy.logic.boolalg import ITE, simplify_logic

          def print_cond(cond):
              """ Problem having an ITE in the cond. """
              if cond.has(ITE):
                  return self._print(simplify_logic(cond))
              else:
                  return self._print(cond)

          exprs = '[{}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
          conds = '[{}]'.format(','.join(print_cond(arg.cond) for arg in expr.args))
          # If [default_value, True] is a (expr, cond) sequence in a Piecewise object
          #     it will behave the same as passing the 'default' kwarg to select()
          #     *as long as* it is the last element in expr.args.
          # If this is not the case, it may be triggered prematurely.
          return '{}({}, {}, default={})'.format(
              self._module_format(self._module + '.select'), conds, exprs,
              self._print(S.NaN))

      def _print_Relational(self, expr):
          "Relational printer for Equality and Unequality"
          op = {
              '==': 'equal',
              '!=': 'not_equal',
              '<': 'less',
              '<=': 'less_equal',
              '>': 'greater',
              '>=': 'greater_equal',
          }
          if expr.rel_op in op:
              lhs = self._print(expr.lhs)
              rhs = self._print(expr.rhs)
              return '{op}({lhs}, {rhs})'.format(
                  op=self._module_format(self._module + '.' + op[expr.rel_op]),
                  lhs=lhs, rhs=rhs)
          # Operators outside the table fall back to the parent printer.
          return super()._print_Relational(expr)

      def _print_And(self, expr):
          "Logical And printer"
          # We have to override LambdaPrinter because it uses Python 'and' keyword.
          # If LambdaPrinter didn't define it, we could use StrPrinter's
          # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
          return '{}.reduce(({}))'.format(
              self._module_format(self._module + '.logical_and'),
              ','.join(self._print(i) for i in expr.args))

      def _print_Or(self, expr):
          "Logical Or printer"
          # We have to override LambdaPrinter because it uses Python 'or' keyword.
          # If LambdaPrinter didn't define it, we could use StrPrinter's
          # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
          return '{}.reduce(({}))'.format(
              self._module_format(self._module + '.logical_or'),
              ','.join(self._print(i) for i in expr.args))

      def _print_Not(self, expr):
          "Logical Not printer"
          # We have to override LambdaPrinter because it uses Python 'not' keyword.
          # If LambdaPrinter didn't define it, we would still have to define our
          #     own because StrPrinter doesn't define it.
          return '{}({})'.format(
              self._module_format(self._module + '.logical_not'),
              ','.join(self._print(i) for i in expr.args))

      def _print_Pow(self, expr, rational=False):
          """Print a power via ``_hprint_Pow`` with ``<module>.sqrt`` as sqrt."""
          # XXX Workaround for negative integer power error
          if expr.exp.is_integer and expr.exp.is_negative:
              expr = Pow(expr.base, expr.exp.evalf(), evaluate=False)
          return self._hprint_Pow(expr, rational=rational, sqrt=self._module + '.sqrt')

      def _helper_minimum_maximum(self, op: str, *args):
          """Print ``args`` combined with the two-argument callable ``op``.

          One argument is printed as-is; more are folded with
          ``functools.reduce``.  Raises NotImplementedError for no arguments.
          """
          if len(args) == 0:
              raise NotImplementedError(f"Need at least one argument for {op}")
          elif len(args) == 1:
              return self._print(args[0])
          _reduce = self._module_format('functools.reduce')
          s_args = [self._print(arg) for arg in args]
          return f"{_reduce}({op}, [{', '.join(s_args)}])"

      def _print_Min(self, expr):
          """Delegate to ``_print_minimum``."""
          return self._print_minimum(expr)

      def _print_amin(self, expr):
          """Print as ``<module>.amin(array, axis=axis)``."""
          return '{}({}, axis={})'.format(
              self._module_format(self._module + '.amin'),
              self._print(expr.array), self._print(expr.axis))

      def _print_minimum(self, expr):
          """Print the arguments combined with ``<module>.minimum``."""
          op = self._module_format(self._module + '.minimum')
          return self._helper_minimum_maximum(op, *expr.args)

      def _print_Max(self, expr):
          """Delegate to ``_print_maximum``."""
          return self._print_maximum(expr)

      def _print_amax(self, expr):
          """Print as ``<module>.amax(array, axis=axis)``."""
          return '{}({}, axis={})'.format(
              self._module_format(self._module + '.amax'),
              self._print(expr.array), self._print(expr.axis))

      def _print_maximum(self, expr):
          """Print the arguments combined with ``<module>.maximum``."""
          op = self._module_format(self._module + '.maximum')
          return self._helper_minimum_maximum(op, *expr.args)

      def _print_arg(self, expr):
          """Print as ``<module>.angle(x)``."""
          return "%s(%s)" % (self._module_format(self._module + '.angle'),
                             self._print(expr.args[0]))

      def _print_im(self, expr):
          """Print as ``<module>.imag(x)``."""
          return "%s(%s)" % (self._module_format(self._module + '.imag'),
                             self._print(expr.args[0]))

      def _print_Mod(self, expr):
          """Print as ``<module>.mod(...)`` with all arguments comma-joined."""
          return "%s(%s)" % (self._module_format(self._module + '.mod'), ', '.join(
              (self._print(arg) for arg in expr.args)))

      def _print_re(self, expr):
          """Print as ``<module>.real(x)``."""
          return "%s(%s)" % (self._module_format(self._module + '.real'),
                             self._print(expr.args[0]))

      def _print_sinc(self, expr):
          """Print as ``<module>.sinc(x/pi)``."""
          # The argument is divided by S.Pi before being printed.
          return "%s(%s)" % (self._module_format(self._module + '.sinc'),
                             self._print(expr.args[0]/S.Pi))

      def _print_MatrixBase(self, expr):
          """Print an explicit matrix as an array-constructor call."""
          if 0 in expr.shape:
              # A matrix with a zero dimension is printed as zeros(shape).
              func = self._module_format(f'{self._module}.{self._zeros}')
              return f"{func}({self._print(expr.shape)})"
          func = self.known_functions.get(expr.__class__.__name__, None)
          if func is None:
              func = self._module_format(f'{self._module}.array')
          return "%s(%s)" % (func, self._print(expr.tolist()))

      def _print_Identity(self, expr):
          """Print as ``<module>.eye(n)``; symbolic dimensions are rejected."""
          shape = expr.shape
          if all(dim.is_Integer for dim in shape):
              return "%s(%s)" % (self._module_format(self._module + '.eye'),
                                 self._print(expr.shape[0]))
          else:
              raise NotImplementedError("Symbolic matrix dimensions are not yet supported for identity matrices")

      def _print_BlockMatrix(self, expr):
          """Print as ``<module>.block(nested list of blocks)``."""
          return '{}({})'.format(self._module_format(self._module + '.block'),
                                 self._print(expr.args[0].tolist()))

      def _print_NDimArray(self, expr):
          """Print an N-dimensional array as an array-constructor call."""
          if expr.rank() == 0:
              # Rank 0: print the single element wrapped in array(...).
              func = self._module_format(f'{self._module}.array')
              return f"{func}({self._print(expr[()])})"
          if 0 in expr.shape:
              func = self._module_format(f'{self._module}.{self._zeros}')
              return f"{func}({self._print(expr.shape)})"
          func = self._module_format(f'{self._module}.array')
          return f"{func}({self._print(expr.tolist())})"

      # Function names within the array module; ``_zeros`` is read by
      # ``_print_MatrixBase`` and ``_print_NDimArray`` above.
      _add = "add"
      _einsum = "einsum"
      _transpose = "transpose"
      _ones = "ones"
      _zeros = "zeros"

      # Routed to CodePrinter's not-supported handler for this printer.
      _print_lowergamma = CodePrinter._print_not_supported
      _print_uppergamma = CodePrinter._print_not_supported
      _print_fresnelc = CodePrinter._print_not_supported
      _print_fresnels = CodePrinter._print_not_supported
  # Install the shared known-function / known-constant printers on
  # NumPyPrinter as ``_print_<name>`` for every entry of the NumPy tables.
  for func in _numpy_known_functions:
      setattr(NumPyPrinter, '_print_' + func, _print_known_func)

  for const in _numpy_known_constants:
      setattr(NumPyPrinter, '_print_' + const, _print_known_const)
  # SymPy function name -> function name inside ``scipy.special``.
  _known_functions_scipy_special = {
      'Ei': 'expi',
      'erf': 'erf',
      'erfc': 'erfc',
      'besselj': 'jv',
      'bessely': 'yv',
      'besseli': 'iv',
      'besselk': 'kv',
      'cosm1': 'cosm1',
      'powm1': 'powm1',
      'factorial': 'factorial',
      'gamma': 'gamma',
      'loggamma': 'gammaln',
      'digamma': 'psi',
      'polygamma': 'polygamma',
      'RisingFactorial': 'poch',
      'jacobi': 'eval_jacobi',
      'gegenbauer': 'eval_gegenbauer',
      'chebyshevt': 'eval_chebyt',
      'chebyshevu': 'eval_chebyu',
      'legendre': 'eval_legendre',
      'hermite': 'eval_hermite',
      'laguerre': 'eval_laguerre',
      'assoc_laguerre': 'eval_genlaguerre',
      'beta': 'beta',
      'LambertW': 'lambertw',
  }

  # SymPy constant name -> attribute name inside ``scipy.constants``.
  _known_constants_scipy_constants = {
      'GoldenRatio': 'golden_ratio',
      'Pi': 'pi',
  }

  # Fully qualified variants of the two tables above.
  _scipy_known_functions = {
      name: f"scipy.special.{target}"
      for name, target in _known_functions_scipy_special.items()
  }
  _scipy_known_constants = {
      name: f"scipy.constants.{target}"
      for name, target in _known_constants_scipy_constants.items()
  }
  class SciPyPrinter(NumPyPrinter):
      """
      NumPyPrinter extended with ``scipy.special``, ``scipy.constants``,
      ``scipy.sparse`` and ``scipy.integrate`` targets.
      """

      # NumPy tables with the SciPy entries layered on top.
      _kf = {**NumPyPrinter._kf, **_scipy_known_functions}
      _kc = {**NumPyPrinter._kc, **_scipy_known_constants}

      def __init__(self, settings=None):
          """`settings` is forwarded unchanged to ``NumPyPrinter.__init__``."""
          super().__init__(settings=settings)
          # Set after the parent call, which assigns its own ``language``.
          self.language = "Python with SciPy and NumPy"

      def _print_SparseRepMatrix(self, expr):
          """Print as ``scipy.sparse.coo_matrix((data, (i, j)), shape=shape)``."""
          i, j, data = [], [], []
          for (r, c), v in expr.todok().items():
              i.append(r)
              j.append(c)
              data.append(v)

          # NOTE(review): the lists and the shape are interpolated directly by
          # str.format rather than through self._print.
          return "{name}(({data}, ({i}, {j})), shape={shape})".format(
              name=self._module_format('scipy.sparse.coo_matrix'),
              data=data, i=i, j=j, shape=expr.shape
          )

      _print_ImmutableSparseMatrix = _print_SparseRepMatrix

      # SciPy's lpmv has a different order of arguments from assoc_legendre
      def _print_assoc_legendre(self, expr):
          """Print as ``scipy.special.lpmv(args[1], args[0], args[2])``."""
          return "{0}({2}, {1}, {3})".format(
              self._module_format('scipy.special.lpmv'),
              self._print(expr.args[0]),
              self._print(expr.args[1]),
              self._print(expr.args[2]))

      def _print_lowergamma(self, expr):
          """Print as ``gamma(a)*gammainc(a, x)`` from ``scipy.special``."""
          return "{0}({2})*{1}({2}, {3})".format(
              self._module_format('scipy.special.gamma'),
              self._module_format('scipy.special.gammainc'),
              self._print(expr.args[0]),
              self._print(expr.args[1]))

      def _print_uppergamma(self, expr):
          """Print as ``gamma(a)*gammaincc(a, x)`` from ``scipy.special``."""
          return "{0}({2})*{1}({2}, {3})".format(
              self._module_format('scipy.special.gamma'),
              self._module_format('scipy.special.gammaincc'),
              self._print(expr.args[0]),
              self._print(expr.args[1]))

      def _print_betainc(self, expr):
          """Print as ``(betainc(a, b, x2) - betainc(a, b, x1)) * beta(a, b)``."""
          betainc = self._module_format('scipy.special.betainc')
          beta = self._module_format('scipy.special.beta')
          args = [self._print(arg) for arg in expr.args]
          # Adjacent literals instead of a backslash continuation inside the
          # string, so this file's indentation does not end up in the
          # generated code.
          return (
              f"({betainc}({args[0]}, {args[1]}, {args[3]})"
              f" - {betainc}({args[0]}, {args[1]}, {args[2]}))"
              f" * {beta}({args[0]}, {args[1]})"
          )

      def _print_betainc_regularized(self, expr):
          """Print as ``betainc(a, b, x2) - betainc(a, b, x1)``."""
          return "{0}({1}, {2}, {4}) - {0}({1}, {2}, {3})".format(
              self._module_format('scipy.special.betainc'),
              self._print(expr.args[0]),
              self._print(expr.args[1]),
              self._print(expr.args[2]),
              self._print(expr.args[3]))

      def _print_fresnels(self, expr):
          """Print as element 0 of ``scipy.special.fresnel(x)``."""
          return "{}({})[0]".format(
              self._module_format("scipy.special.fresnel"),
              self._print(expr.args[0]))

      def _print_fresnelc(self, expr):
          """Print as element 1 of ``scipy.special.fresnel(x)``."""
          return "{}({})[1]".format(
              self._module_format("scipy.special.fresnel"),
              self._print(expr.args[0]))

      def _print_airyai(self, expr):
          """Print as element 0 of ``scipy.special.airy(x)``."""
          return "{}({})[0]".format(
              self._module_format("scipy.special.airy"),
              self._print(expr.args[0]))

      def _print_airyaiprime(self, expr):
          """Print as element 1 of ``scipy.special.airy(x)``."""
          return "{}({})[1]".format(
              self._module_format("scipy.special.airy"),
              self._print(expr.args[0]))

      def _print_airybi(self, expr):
          """Print as element 2 of ``scipy.special.airy(x)``."""
          return "{}({})[2]".format(
              self._module_format("scipy.special.airy"),
              self._print(expr.args[0]))

      def _print_airybiprime(self, expr):
          """Print as element 3 of ``scipy.special.airy(x)``."""
          return "{}({})[3]".format(
              self._module_format("scipy.special.airy"),
              self._print(expr.args[0]))

      def _print_bernoulli(self, expr):
          """Print the zeta rewrite of the expression."""
          # scipy's bernoulli is inconsistent with SymPy's so rewrite
          return self._print(expr._eval_rewrite_as_zeta(*expr.args))

      def _print_harmonic(self, expr):
          """Print the zeta rewrite of the expression."""
          return self._print(expr._eval_rewrite_as_zeta(*expr.args))

      def _print_Integral(self, e):
          """Print as a ``scipy.integrate`` call, keeping element 0 of its result."""
          integration_vars, limits = _unpack_integral_limits(e)

          if len(limits) == 1:
              # nicer (but not necessary) to prefer quad over nquad for 1D case
              module_str = self._module_format("scipy.integrate.quad")
              limit_str = "%s, %s" % tuple(map(self._print, limits[0]))
          else:
              module_str = self._module_format("scipy.integrate.nquad")
              limit_str = "({})".format(", ".join(
                  "(%s, %s)" % tuple(map(self._print, l)) for l in limits))

          return "{}(lambda {}: {}, {})[0]".format(
              module_str,
              ", ".join(map(self._print, integration_vars)),
              self._print(e.args[0]),
              limit_str)

      def _print_Si(self, expr):
          """Print as element 0 of ``scipy.special.sici(x)``."""
          return "{}({})[0]".format(
              self._module_format("scipy.special.sici"),
              self._print(expr.args[0]))

      def _print_Ci(self, expr):
          """Print as element 1 of ``scipy.special.sici(x)``."""
          return "{}({})[1]".format(
              self._module_format("scipy.special.sici"),
              self._print(expr.args[0]))
  # Install the shared printers on SciPyPrinter for every SciPy table entry.
  for func in _scipy_known_functions:
      setattr(SciPyPrinter, '_print_' + func, _print_known_func)

  for const in _scipy_known_constants:
      setattr(SciPyPrinter, '_print_' + const, _print_known_const)

  # NumPy name tables re-qualified with the ``cupy.`` prefix.
  _cupy_known_functions = {
      name: f"cupy.{target}" for name, target in _known_functions_numpy.items()
  }
  _cupy_known_constants = {
      name: f"cupy.{target}" for name, target in _known_constants_numpy.items()
  }
  class CuPyPrinter(NumPyPrinter):
      """
      CuPy printer which handles vectorized piecewise functions,
      logical operators, etc.

      Differs from NumPyPrinter only in the class attributes below.
      """

      # Module name used to build emitted callables, and the cupy-qualified
      # function / constant tables.
      _module = 'cupy'
      _kf = _cupy_known_functions
      _kc = _cupy_known_constants

      def __init__(self, settings=None):
          """`settings` is forwarded unchanged to ``NumPyPrinter.__init__``."""
          super().__init__(settings=settings)
  # Install the shared printers on CuPyPrinter for every CuPy table entry.
  for func in _cupy_known_functions:
      setattr(CuPyPrinter, '_print_' + func, _print_known_func)

  for const in _cupy_known_constants:
      setattr(CuPyPrinter, '_print_' + const, _print_known_const)

  # NumPy name tables re-qualified with the ``jax.numpy.`` prefix.
  _jax_known_functions = {
      name: f'jax.numpy.{target}' for name, target in _known_functions_numpy.items()
  }
  _jax_known_constants = {
      name: f'jax.numpy.{target}' for name, target in _known_constants_numpy.items()
  }
  class JaxPrinter(NumPyPrinter):
      """
      JAX printer which handles vectorized piecewise functions,
      logical operators, etc.
      """

      _module = "jax.numpy"
      _kf = _jax_known_functions
      _kc = _jax_known_constants

      def __init__(self, settings=None):
          """`settings` is forwarded unchanged to ``NumPyPrinter.__init__``."""
          super().__init__(settings=settings)
          # Replace the print-method name assigned by the parent constructor.
          self.printmethod = '_jaxcode'

      # And/Or are overridden because there is no "jax.numpy.reduce" to lean
      # on: the operands are packed into one array and reduced along axis 0.
      def _print_And(self, expr):
          "Logical And printer: all() over the stacked operands"
          reducer = self._module_format(self._module + ".all")
          namespace = self._module_format(self._module)
          operands = ",".join(self._print(i) for i in expr.args)
          return f"{reducer}({namespace}.asarray([{operands}]), axis=0)"

      def _print_Or(self, expr):
          "Logical Or printer: any() over the stacked operands"
          reducer = self._module_format(self._module + ".any")
          namespace = self._module_format(self._module)
          operands = ",".join(self._print(i) for i in expr.args)
          return f"{reducer}({namespace}.asarray([{operands}]), axis=0)"
  # Install the shared printers on JaxPrinter for every JAX table entry.
  for func in _jax_known_functions:
      setattr(JaxPrinter, '_print_' + func, _print_known_func)

  for const in _jax_known_constants:
      setattr(JaxPrinter, '_print_' + const, _print_known_const)