fnodes.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658
"""
AST nodes specific to Fortran.

The functions defined in this module allow the user to express functions such as ``dsign``
as SymPy functions for symbolic manipulation.
"""
  6. from __future__ import annotations
  7. from sympy.codegen.ast import (
  8. Attribute, CodeBlock, FunctionCall, Node, none, String,
  9. Token, _mk_Tuple, Variable
  10. )
  11. from sympy.core.basic import Basic
  12. from sympy.core.containers import Tuple
  13. from sympy.core.expr import Expr
  14. from sympy.core.function import Function
  15. from sympy.core.numbers import Float, Integer
  16. from sympy.core.symbol import Str
  17. from sympy.core.sympify import sympify
  18. from sympy.logic import true, false
  19. from sympy.utilities.iterables import iterable
  20. pure = Attribute('pure')
  21. elemental = Attribute('elemental') # (all elemental procedures are also pure)
  22. intent_in = Attribute('intent_in')
  23. intent_out = Attribute('intent_out')
  24. intent_inout = Attribute('intent_inout')
  25. allocatable = Attribute('allocatable')
  26. class Program(Token):
  27. """ Represents a 'program' block in Fortran.
  28. Examples
  29. ========
  30. >>> from sympy.codegen.ast import Print
  31. >>> from sympy.codegen.fnodes import Program
  32. >>> prog = Program('myprogram', [Print([42])])
  33. >>> from sympy import fcode
  34. >>> print(fcode(prog, source_format='free'))
  35. program myprogram
  36. print *, 42
  37. end program
  38. """
  39. __slots__ = _fields = ('name', 'body')
  40. _construct_name = String
  41. _construct_body = staticmethod(lambda body: CodeBlock(*body))
  42. class use_rename(Token):
  43. """ Represents a renaming in a use statement in Fortran.
  44. Examples
  45. ========
  46. >>> from sympy.codegen.fnodes import use_rename, use
  47. >>> from sympy import fcode
  48. >>> ren = use_rename("thingy", "convolution2d")
  49. >>> print(fcode(ren, source_format='free'))
  50. thingy => convolution2d
  51. >>> full = use('signallib', only=['snr', ren])
  52. >>> print(fcode(full, source_format='free'))
  53. use signallib, only: snr, thingy => convolution2d
  54. """
  55. __slots__ = _fields = ('local', 'original')
  56. _construct_local = String
  57. _construct_original = String
  58. def _name(arg):
  59. if hasattr(arg, 'name'):
  60. return arg.name
  61. else:
  62. return String(arg)
  63. class use(Token):
  64. """ Represents a use statement in Fortran.
  65. Examples
  66. ========
  67. >>> from sympy.codegen.fnodes import use
  68. >>> from sympy import fcode
  69. >>> fcode(use('signallib'), source_format='free')
  70. 'use signallib'
  71. >>> fcode(use('signallib', [('metric', 'snr')]), source_format='free')
  72. 'use signallib, metric => snr'
  73. >>> fcode(use('signallib', only=['snr', 'convolution2d']), source_format='free')
  74. 'use signallib, only: snr, convolution2d'
  75. """
  76. __slots__ = _fields = ('namespace', 'rename', 'only')
  77. defaults = {'rename': none, 'only': none}
  78. _construct_namespace = staticmethod(_name)
  79. _construct_rename = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else use_rename(*arg) for arg in args]))
  80. _construct_only = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else _name(arg) for arg in args]))
  81. class Module(Token):
  82. """ Represents a module in Fortran.
  83. Examples
  84. ========
  85. >>> from sympy.codegen.fnodes import Module
  86. >>> from sympy import fcode
  87. >>> print(fcode(Module('signallib', ['implicit none'], []), source_format='free'))
  88. module signallib
  89. implicit none
  90. <BLANKLINE>
  91. contains
  92. <BLANKLINE>
  93. <BLANKLINE>
  94. end module
  95. """
  96. __slots__ = _fields = ('name', 'declarations', 'definitions')
  97. defaults = {'declarations': Tuple()}
  98. _construct_name = String
  99. @classmethod
  100. def _construct_declarations(cls, args):
  101. args = [Str(arg) if isinstance(arg, str) else arg for arg in args]
  102. return CodeBlock(*args)
  103. _construct_definitions = staticmethod(lambda arg: CodeBlock(*arg))
  104. class Subroutine(Node):
  105. """ Represents a subroutine in Fortran.
  106. Examples
  107. ========
  108. >>> from sympy import fcode, symbols
  109. >>> from sympy.codegen.ast import Print
  110. >>> from sympy.codegen.fnodes import Subroutine
  111. >>> x, y = symbols('x y', real=True)
  112. >>> sub = Subroutine('mysub', [x, y], [Print([x**2 + y**2, x*y])])
  113. >>> print(fcode(sub, source_format='free', standard=2003))
  114. subroutine mysub(x, y)
  115. real*8 :: x
  116. real*8 :: y
  117. print *, x**2 + y**2, x*y
  118. end subroutine
  119. """
  120. __slots__ = ('name', 'parameters', 'body')
  121. _fields = __slots__ + Node._fields
  122. _construct_name = String
  123. _construct_parameters = staticmethod(lambda params: Tuple(*map(Variable.deduced, params)))
  124. @classmethod
  125. def _construct_body(cls, itr):
  126. if isinstance(itr, CodeBlock):
  127. return itr
  128. else:
  129. return CodeBlock(*itr)
  130. class SubroutineCall(Token):
  131. """ Represents a call to a subroutine in Fortran.
  132. Examples
  133. ========
  134. >>> from sympy.codegen.fnodes import SubroutineCall
  135. >>> from sympy import fcode
  136. >>> fcode(SubroutineCall('mysub', 'x y'.split()))
  137. ' call mysub(x, y)'
  138. """
  139. __slots__ = _fields = ('name', 'subroutine_args')
  140. _construct_name = staticmethod(_name)
  141. _construct_subroutine_args = staticmethod(_mk_Tuple)
  142. class Do(Token):
  143. """ Represents a Do loop in in Fortran.
  144. Examples
  145. ========
  146. >>> from sympy import fcode, symbols
  147. >>> from sympy.codegen.ast import aug_assign, Print
  148. >>> from sympy.codegen.fnodes import Do
  149. >>> i, n = symbols('i n', integer=True)
  150. >>> r = symbols('r', real=True)
  151. >>> body = [aug_assign(r, '+', 1/i), Print([i, r])]
  152. >>> do1 = Do(body, i, 1, n)
  153. >>> print(fcode(do1, source_format='free'))
  154. do i = 1, n
  155. r = r + 1d0/i
  156. print *, i, r
  157. end do
  158. >>> do2 = Do(body, i, 1, n, 2)
  159. >>> print(fcode(do2, source_format='free'))
  160. do i = 1, n, 2
  161. r = r + 1d0/i
  162. print *, i, r
  163. end do
  164. """
  165. __slots__ = _fields = ('body', 'counter', 'first', 'last', 'step', 'concurrent')
  166. defaults = {'step': Integer(1), 'concurrent': false}
  167. _construct_body = staticmethod(lambda body: CodeBlock(*body))
  168. _construct_counter = staticmethod(sympify)
  169. _construct_first = staticmethod(sympify)
  170. _construct_last = staticmethod(sympify)
  171. _construct_step = staticmethod(sympify)
  172. _construct_concurrent = staticmethod(lambda arg: true if arg else false)
  173. class ArrayConstructor(Token):
  174. """ Represents an array constructor.
  175. Examples
  176. ========
  177. >>> from sympy import fcode
  178. >>> from sympy.codegen.fnodes import ArrayConstructor
  179. >>> ac = ArrayConstructor([1, 2, 3])
  180. >>> fcode(ac, standard=95, source_format='free')
  181. '(/1, 2, 3/)'
  182. >>> fcode(ac, standard=2003, source_format='free')
  183. '[1, 2, 3]'
  184. """
  185. __slots__ = _fields = ('elements',)
  186. _construct_elements = staticmethod(_mk_Tuple)
  187. class ImpliedDoLoop(Token):
  188. """ Represents an implied do loop in Fortran.
  189. Examples
  190. ========
  191. >>> from sympy import Symbol, fcode
  192. >>> from sympy.codegen.fnodes import ImpliedDoLoop, ArrayConstructor
  193. >>> i = Symbol('i', integer=True)
  194. >>> idl = ImpliedDoLoop(i**3, i, -3, 3, 2) # -27, -1, 1, 27
  195. >>> ac = ArrayConstructor([-28, idl, 28]) # -28, -27, -1, 1, 27, 28
  196. >>> fcode(ac, standard=2003, source_format='free')
  197. '[-28, (i**3, i = -3, 3, 2), 28]'
  198. """
  199. __slots__ = _fields = ('expr', 'counter', 'first', 'last', 'step')
  200. defaults = {'step': Integer(1)}
  201. _construct_expr = staticmethod(sympify)
  202. _construct_counter = staticmethod(sympify)
  203. _construct_first = staticmethod(sympify)
  204. _construct_last = staticmethod(sympify)
  205. _construct_step = staticmethod(sympify)
  206. class Extent(Basic):
  207. """ Represents a dimension extent.
  208. Examples
  209. ========
  210. >>> from sympy.codegen.fnodes import Extent
  211. >>> e = Extent(-3, 3) # -3, -2, -1, 0, 1, 2, 3
  212. >>> from sympy import fcode
  213. >>> fcode(e, source_format='free')
  214. '-3:3'
  215. >>> from sympy.codegen.ast import Variable, real
  216. >>> from sympy.codegen.fnodes import dimension, intent_out
  217. >>> dim = dimension(e, e)
  218. >>> arr = Variable('x', real, attrs=[dim, intent_out])
  219. >>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
  220. 'real*8, dimension(-3:3, -3:3), intent(out) :: x'
  221. """
  222. def __new__(cls, *args):
  223. if len(args) == 2:
  224. low, high = args
  225. return Basic.__new__(cls, sympify(low), sympify(high))
  226. elif len(args) == 0 or (len(args) == 1 and args[0] in (':', None)):
  227. return Basic.__new__(cls) # assumed shape
  228. else:
  229. raise ValueError("Expected 0 or 2 args (or one argument == None or ':')")
  230. def _sympystr(self, printer):
  231. if len(self.args) == 0:
  232. return ':'
  233. return ":".join(str(arg) for arg in self.args)
  234. assumed_extent = Extent() # or Extent(':'), Extent(None)
  235. def dimension(*args):
  236. """ Creates a 'dimension' Attribute with (up to 7) extents.
  237. Examples
  238. ========
  239. >>> from sympy import fcode
  240. >>> from sympy.codegen.fnodes import dimension, intent_in
  241. >>> dim = dimension('2', ':') # 2 rows, runtime determined number of columns
  242. >>> from sympy.codegen.ast import Variable, integer
  243. >>> arr = Variable('a', integer, attrs=[dim, intent_in])
  244. >>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
  245. 'integer*4, dimension(2, :), intent(in) :: a'
  246. """
  247. if len(args) > 7:
  248. raise ValueError("Fortran only supports up to 7 dimensional arrays")
  249. parameters = []
  250. for arg in args:
  251. if isinstance(arg, Extent):
  252. parameters.append(arg)
  253. elif isinstance(arg, str):
  254. if arg == ':':
  255. parameters.append(Extent())
  256. else:
  257. parameters.append(String(arg))
  258. elif iterable(arg):
  259. parameters.append(Extent(*arg))
  260. else:
  261. parameters.append(sympify(arg))
  262. if len(args) == 0:
  263. raise ValueError("Need at least one dimension")
  264. return Attribute('dimension', parameters)
  265. assumed_size = dimension('*')
  266. def array(symbol, dim, intent=None, *, attrs=(), value=None, type=None):
  267. """ Convenience function for creating a Variable instance for a Fortran array.
  268. Parameters
  269. ==========
  270. symbol : symbol
  271. dim : Attribute or iterable
  272. If dim is an ``Attribute`` it need to have the name 'dimension'. If it is
  273. not an ``Attribute``, then it is passed to :func:`dimension` as ``*dim``
  274. intent : str
  275. One of: 'in', 'out', 'inout' or None
  276. \\*\\*kwargs:
  277. Keyword arguments for ``Variable`` ('type' & 'value')
  278. Examples
  279. ========
  280. >>> from sympy import fcode
  281. >>> from sympy.codegen.ast import integer, real
  282. >>> from sympy.codegen.fnodes import array
  283. >>> arr = array('a', '*', 'in', type=integer)
  284. >>> print(fcode(arr.as_Declaration(), source_format='free', standard=2003))
  285. integer*4, dimension(*), intent(in) :: a
  286. >>> x = array('x', [3, ':', ':'], intent='out', type=real)
  287. >>> print(fcode(x.as_Declaration(value=1), source_format='free', standard=2003))
  288. real*8, dimension(3, :, :), intent(out) :: x = 1
  289. """
  290. if isinstance(dim, Attribute):
  291. if str(dim.name) != 'dimension':
  292. raise ValueError("Got an unexpected Attribute argument as dim: %s" % str(dim))
  293. else:
  294. dim = dimension(*dim)
  295. attrs = list(attrs) + [dim]
  296. if intent is not None:
  297. if intent not in (intent_in, intent_out, intent_inout):
  298. intent = {'in': intent_in, 'out': intent_out, 'inout': intent_inout}[intent]
  299. attrs.append(intent)
  300. if type is None:
  301. return Variable.deduced(symbol, value=value, attrs=attrs)
  302. else:
  303. return Variable(symbol, type, value=value, attrs=attrs)
  304. def _printable(arg):
  305. return String(arg) if isinstance(arg, str) else sympify(arg)
  306. def allocated(array):
  307. """ Creates an AST node for a function call to Fortran's "allocated(...)"
  308. Examples
  309. ========
  310. >>> from sympy import fcode
  311. >>> from sympy.codegen.fnodes import allocated
  312. >>> alloc = allocated('x')
  313. >>> fcode(alloc, source_format='free')
  314. 'allocated(x)'
  315. """
  316. return FunctionCall('allocated', [_printable(array)])
  317. def lbound(array, dim=None, kind=None):
  318. """ Creates an AST node for a function call to Fortran's "lbound(...)"
  319. Parameters
  320. ==========
  321. array : Symbol or String
  322. dim : expr
  323. kind : expr
  324. Examples
  325. ========
  326. >>> from sympy import fcode
  327. >>> from sympy.codegen.fnodes import lbound
  328. >>> lb = lbound('arr', dim=2)
  329. >>> fcode(lb, source_format='free')
  330. 'lbound(arr, 2)'
  331. """
  332. return FunctionCall(
  333. 'lbound',
  334. [_printable(array)] +
  335. ([_printable(dim)] if dim else []) +
  336. ([_printable(kind)] if kind else [])
  337. )
  338. def ubound(array, dim=None, kind=None):
  339. return FunctionCall(
  340. 'ubound',
  341. [_printable(array)] +
  342. ([_printable(dim)] if dim else []) +
  343. ([_printable(kind)] if kind else [])
  344. )
  345. def shape(source, kind=None):
  346. """ Creates an AST node for a function call to Fortran's "shape(...)"
  347. Parameters
  348. ==========
  349. source : Symbol or String
  350. kind : expr
  351. Examples
  352. ========
  353. >>> from sympy import fcode
  354. >>> from sympy.codegen.fnodes import shape
  355. >>> shp = shape('x')
  356. >>> fcode(shp, source_format='free')
  357. 'shape(x)'
  358. """
  359. return FunctionCall(
  360. 'shape',
  361. [_printable(source)] +
  362. ([_printable(kind)] if kind else [])
  363. )
  364. def size(array, dim=None, kind=None):
  365. """ Creates an AST node for a function call to Fortran's "size(...)"
  366. Examples
  367. ========
  368. >>> from sympy import fcode, Symbol
  369. >>> from sympy.codegen.ast import FunctionDefinition, real, Return
  370. >>> from sympy.codegen.fnodes import array, sum_, size
  371. >>> a = Symbol('a', real=True)
  372. >>> body = [Return((sum_(a**2)/size(a))**.5)]
  373. >>> arr = array(a, dim=[':'], intent='in')
  374. >>> fd = FunctionDefinition(real, 'rms', [arr], body)
  375. >>> print(fcode(fd, source_format='free', standard=2003))
  376. real*8 function rms(a)
  377. real*8, dimension(:), intent(in) :: a
  378. rms = sqrt(sum(a**2)*1d0/size(a))
  379. end function
  380. """
  381. return FunctionCall(
  382. 'size',
  383. [_printable(array)] +
  384. ([_printable(dim)] if dim else []) +
  385. ([_printable(kind)] if kind else [])
  386. )
  387. def reshape(source, shape, pad=None, order=None):
  388. """ Creates an AST node for a function call to Fortran's "reshape(...)"
  389. Parameters
  390. ==========
  391. source : Symbol or String
  392. shape : ArrayExpr
  393. """
  394. return FunctionCall(
  395. 'reshape',
  396. [_printable(source), _printable(shape)] +
  397. ([_printable(pad)] if pad else []) +
  398. ([_printable(order)] if pad else [])
  399. )
  400. def bind_C(name=None):
  401. """ Creates an Attribute ``bind_C`` with a name.
  402. Parameters
  403. ==========
  404. name : str
  405. Examples
  406. ========
  407. >>> from sympy import fcode, Symbol
  408. >>> from sympy.codegen.ast import FunctionDefinition, real, Return
  409. >>> from sympy.codegen.fnodes import array, sum_, bind_C
  410. >>> a = Symbol('a', real=True)
  411. >>> s = Symbol('s', integer=True)
  412. >>> arr = array(a, dim=[s], intent='in')
  413. >>> body = [Return((sum_(a**2)/s)**.5)]
  414. >>> fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
  415. >>> print(fcode(fd, source_format='free', standard=2003))
  416. real*8 function rms(a, s) bind(C, name="rms")
  417. real*8, dimension(s), intent(in) :: a
  418. integer*4 :: s
  419. rms = sqrt(sum(a**2)/s)
  420. end function
  421. """
  422. return Attribute('bind_C', [String(name)] if name else [])
  423. class GoTo(Token):
  424. """ Represents a goto statement in Fortran
  425. Examples
  426. ========
  427. >>> from sympy.codegen.fnodes import GoTo
  428. >>> go = GoTo([10, 20, 30], 'i')
  429. >>> from sympy import fcode
  430. >>> fcode(go, source_format='free')
  431. 'go to (10, 20, 30), i'
  432. """
  433. __slots__ = _fields = ('labels', 'expr')
  434. defaults = {'expr': none}
  435. _construct_labels = staticmethod(_mk_Tuple)
  436. _construct_expr = staticmethod(sympify)
  437. class FortranReturn(Token):
  438. """ AST node explicitly mapped to a fortran "return".
  439. Explanation
  440. ===========
  441. Because a return statement in fortran is different from C, and
  442. in order to aid reuse of our codegen ASTs the ordinary
  443. ``.codegen.ast.Return`` is interpreted as assignment to
  444. the result variable of the function. If one for some reason needs
  445. to generate a fortran RETURN statement, this node should be used.
  446. Examples
  447. ========
  448. >>> from sympy.codegen.fnodes import FortranReturn
  449. >>> from sympy import fcode
  450. >>> fcode(FortranReturn('x'))
  451. ' return x'
  452. """
  453. __slots__ = _fields = ('return_value',)
  454. defaults = {'return_value': none}
  455. _construct_return_value = staticmethod(sympify)
  456. class FFunction(Function):
  457. _required_standard = 77
  458. def _fcode(self, printer):
  459. name = self.__class__.__name__
  460. if printer._settings['standard'] < self._required_standard:
  461. raise NotImplementedError("%s requires Fortran %d or newer" %
  462. (name, self._required_standard))
  463. return '{}({})'.format(name, ', '.join(map(printer._print, self.args)))
  464. class F95Function(FFunction):
  465. _required_standard = 95
  466. class isign(FFunction):
  467. """ Fortran sign intrinsic for integer arguments. """
  468. nargs = 2
  469. class dsign(FFunction):
  470. """ Fortran sign intrinsic for double precision arguments. """
  471. nargs = 2
  472. class cmplx(FFunction):
  473. """ Fortran complex conversion function. """
  474. nargs = 2 # may be extended to (2, 3) at a later point
  475. class kind(FFunction):
  476. """ Fortran kind function. """
  477. nargs = 1
  478. class merge(F95Function):
  479. """ Fortran merge function """
  480. nargs = 3
  481. class _literal(Float):
  482. _token: str
  483. _decimals: int
  484. def _fcode(self, printer, *args, **kwargs):
  485. mantissa, sgnd_ex = ('%.{}e'.format(self._decimals) % self).split('e')
  486. mantissa = mantissa.strip('0').rstrip('.')
  487. ex_sgn, ex_num = sgnd_ex[0], sgnd_ex[1:].lstrip('0')
  488. ex_sgn = '' if ex_sgn == '+' else ex_sgn
  489. return (mantissa or '0') + self._token + ex_sgn + (ex_num or '0')
  490. class literal_sp(_literal):
  491. """ Fortran single precision real literal """
  492. _token = 'e'
  493. _decimals = 9
  494. class literal_dp(_literal):
  495. """ Fortran double precision real literal """
  496. _token = 'd'
  497. _decimals = 17
  498. class sum_(Token, Expr):
  499. __slots__ = _fields = ('array', 'dim', 'mask')
  500. defaults = {'dim': none, 'mask': none}
  501. _construct_array = staticmethod(sympify)
  502. _construct_dim = staticmethod(sympify)
  503. class product_(Token, Expr):
  504. __slots__ = _fields = ('array', 'dim', 'mask')
  505. defaults = {'dim': none, 'mask': none}
  506. _construct_array = staticmethod(sympify)
  507. _construct_dim = staticmethod(sympify)