transform.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510
  1. # Copyright (c) 2019-2024 by Rocky Bernstein
  2. # This program is free software: you can redistribute it and/or modify
  3. # it under the terms of the GNU General Public License as published by
  4. # the Free Software Foundation, either version 3 of the License, or
  5. # (at your option) any later version.
  6. #
  7. # This program is distributed in the hope that it will be useful,
  8. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. # GNU General Public License for more details.
  11. #
  12. # You should have received a copy of the GNU General Public License
  13. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  14. from copy import copy
  15. from typing import Optional
  16. from spark_parser import GenericASTTraversal, GenericASTTraversalPruningException
  17. from uncompyle6.parsers.treenode import SyntaxTree
  18. from uncompyle6.scanners.tok import NoneToken, Token
  19. from uncompyle6.semantics.consts import ASSIGN_DOC_STRING, RETURN_NONE
  20. from uncompyle6.semantics.helper import find_code_node
  21. from uncompyle6.show import maybe_show_tree
  def is_docstring(node, version, co_consts):
      """Return True if *node* is the statement that assigns the docstring.

      The node is compared against the pattern built by ``ASSIGN_DOC_STRING``
      from the first code constant (presumably an assignment of that constant
      to ``__doc__`` -- see the TODO below). An ``sstmt`` wrapper around the
      statement is looked through first.

      :param node: parse-tree node (possibly an ``sstmt``) to test.
      :param version: bytecode version tuple, e.g. ``(3, 8)``.
      :param co_consts: constants of the code object being decompiled.
      :return: True when *node* matches the docstring pattern; False
          otherwise, including when *co_consts* is empty (there is then no
          constant that could be a docstring).
      """
      # Without any constants there is nothing that could be the docstring;
      # indexing co_consts[0] below would raise IndexError instead.
      if not co_consts:
          return False

      if node == "sstmt":
          node = node[0]

      # TODO: the test below on 2.7 succeeds for
      #   class OldClass:
      #       __doc__ = DocDescr()
      # which produces:
      #
      #   assign (2)
      #     0. expr
      #        call (2)
      #          0. expr
      #             L. 16 6 LOAD_DEREF 0 'DocDescr'
      #          1. 9 CALL_FUNCTION_0 0 None
      #     1. store
      #
      # See Python 2.7 test_descr.py.
      # If ASSIGN_DOC_STRING doesn't work, we need something like the code
      # below, but more elaborate, to address the above:
      #   try:
      #       return node.kind == "assign" and node[1][0].pattr == "__doc__"
      #   except:
      #       return False

      # For 2.7 and earlier the string is matched as LOAD_CONST; later
      # versions match it as LOAD_STR.
      doc_load = "LOAD_CONST" if version <= (2, 7) else "LOAD_STR"
      return node == ASSIGN_DOC_STRING(co_consts[0], doc_load)
  def is_not_docstring(call_stmt_node) -> bool:
      """Tell whether *call_stmt_node* is a bare string expression statement.

      Such a node is a ``call_stmt`` whose expression child starts with a
      ``LOAD_STR`` and whose second child is ``POP_TOP``: a string that is
      evaluated and discarded, as opposed to a real docstring.

      Any error raised while probing the node's shape is treated as "no".
      """
      try:
          # Each check short-circuits with the falsy comparison result,
          # mirroring an ``and`` chain of the three tests.
          is_call_stmt = call_stmt_node == "call_stmt"
          if not is_call_stmt:
              return is_call_stmt
          starts_with_string = call_stmt_node[0][0] == "LOAD_STR"
          if not starts_with_string:
              return starts_with_string
          return call_stmt_node[1] == "POP_TOP"
      except Exception:
          return False
  class TreeTransform(GenericASTTraversal, object):
      """Rewrite a parse tree into a shape that is easier to turn into source.

      The walk is (roughly) preorder: for a node of kind ``K``, the method
      ``n_K``, when defined, is called and its return value replaces the node.
      The rewrites here pull docstrings into explicit ``docstring`` nodes,
      turn ``if ...: raise AssertionError`` shapes into ``assert`` nodes,
      fold ``else: if`` chains into ``elif`` nodes, rewrite some 3.7 imports
      into ``import_as37``, and merge an assignment followed by an annotation
      of the same name into ``ann_assign_init``.
      """

      def __init__(
          self,
          version: tuple,
          is_pypy=False,
          show_ast: Optional[dict] = None,
      ):
          """
          :param version: bytecode version tuple being decompiled, e.g. (3, 8).
          :param is_pypy: True for PyPy bytecode; selects the ``assert0_pypy``
              node kind in ``n_ifstmt``.
          :param show_ast: optional dict of debug flags; when its "before" or
              "after" entry is set, ``maybe_show_tree`` dumps the tree.
          """
          self.version = version
          self.showast = show_ast
          self.is_pypy = is_pypy

      def maybe_show_tree(self, tree):
          """Dump *tree* with ``uncompyle6.show.maybe_show_tree`` when the
          "before" or "after" flag is set in ``self.showast``."""
          if isinstance(self.showast, dict) and (
              self.showast.get("before") or self.showast.get("after")
          ):
              # This resolves to the imported module-level function, not to
              # this method.
              maybe_show_tree(self, tree)

      def preorder(self, node=None):
          """Walk the tree in roughly 'preorder' (a bit of a lie explained below).

          For each node with typestring name *name*, if this object has a
          method called ``n_<name>``, call it before walking the children and
          use its return value as the (possibly replaced) node.

          In typical use a node with children can call "preorder" in any
          order it wants, which may skip children or order them in ways other
          than first to last. In fact, this happens, so in this sense this
          function is not strictly preorder.

          :param node: node to visit; defaults to ``self.ast``.
          :return: the node that should take *node*'s place in its parent.
              When a handler prunes the walk by raising
              ``GenericASTTraversalPruningException``, *node* is returned
              unchanged and its children are not visited.
          """
          if node is None:
              node = self.ast

          try:
              name = "n_" + self.typestring(node)
              if hasattr(self, name):
                  func = getattr(self, name)
                  node = func(node)
          except GenericASTTraversalPruningException:
              # Our result is stored back into the parent
              # (``node[i] = self.preorder(kid)``), so returning None here
              # would silently erase the pruned subtree. Keep it intact.
              return node

          for i, kid in enumerate(node):
              node[i] = self.preorder(kid)
          return node

      def n_mkfunc(self, node):
          """If the function has a docstring (found in the code constants),
          pull it out and make it part of the syntax tree.

          When generating source, this ``docstring`` node, rather than the
          code object's constant, is what gets seen and used.
          """
          # Where the code object sits among mkfunc's children depends on
          # the bytecode version.
          code_index = -3 if self.version >= (3, 7) else -2
          code = find_code_node(node, code_index).attr

          mkfunc_pattr = node[-1].pattr
          if isinstance(mkfunc_pattr, tuple):
              # Sanity-check the operand of the last child (presumably the
              # MAKE_FUNCTION token). The former test
              # ``isinstance(mkfunc_pattr, int)`` can never hold for a tuple,
              # so it failed whenever this branch ran; check the items instead.
              assert len(mkfunc_pattr) == 4 and all(
                  isinstance(item, int) for item in mkfunc_pattr
              )

          co_consts = code.co_consts
          if len(co_consts) > 0 and isinstance(co_consts[0], str):
              docstring_node = SyntaxTree(
                  "docstring", [Token("LOAD_STR", has_arg=True, pattr=co_consts[0])]
              )
              docstring_node.transformed_by = "n_mkfunc"
              # Insert the docstring just before the last child.
              node = SyntaxTree("mkfunc", node[:-1] + [docstring_node, node[-1]])
              node.transformed_by = "n_mkfunc"
          return node

      def n_ifstmt(self, node):
          """Check whether an ``ifstmt``, ``ifstmtl`` or ``iflaststmtl`` can be
          turned into some kind of ``assert`` statement.

          The shape recognized is an ``if`` whose body is a single
          ``raise_stmt1`` raising ``AssertionError``. It becomes one of
          ``assert``, ``assertnot``, ``assert0_pypy``, ``assert2`` or
          ``assert2not``. Nodes that do not match are returned unchanged.
          """
          testexpr = node[0]

          if testexpr not in ("testexpr", "testexprl"):
              return node

          if node.kind in ("ifstmt", "ifstmtl"):
              ifstmts_jump = node[1]

              if ifstmts_jump == "_ifstmts_jumpl" and ifstmts_jump[0] == "_ifstmts_jump":
                  ifstmts_jump = ifstmts_jump[0]
              elif ifstmts_jump not in (
                  "_ifstmts_jump",
                  "_ifstmts_jumpl",
                  "ifstmts_jumpl",
              ):
                  return node
              stmts = ifstmts_jump[0]
          else:
              # iflaststmtl works this way
              stmts = node[1]

          if stmts in ("c_stmts", "stmts", "stmts_opt") and len(stmts) == 1:
              raise_stmt = stmts[0]
              if raise_stmt != "raise_stmt1" and len(raise_stmt) > 0:
                  raise_stmt = raise_stmt[0]

              testtrue_or_false = testexpr[0]
              if (
                  raise_stmt.kind == "raise_stmt1"
                  and 1 <= len(testtrue_or_false) <= 2
                  and raise_stmt.first_child().pattr == "AssertionError"
              ):
                  if testtrue_or_false in ("testtrue", "testtruel"):
                      # Skip over the testtrue because it would
                      # produce a "not" and we don't want that here.
                      assert_expr = testtrue_or_false[0]
                      jump_cond = NoneToken
                  else:
                      assert testtrue_or_false in ("testfalse", "testfalsel")
                      assert_expr = testtrue_or_false[0]
                      if assert_expr in ("testfalse_not_and", "and_not"):
                          # FIXME: come back to stuff like this
                          return node

                      jump_cond = testtrue_or_false[1]
                      assert_expr.kind = "assert_expr"
                      pass

                  expr = raise_stmt[0]
                  RAISE_VARARGS_1 = raise_stmt[1]
                  call = expr[0]
                  if call == "call":
                      # ifstmt
                      #   0. testexpr
                      #      testtrue (2)
                      #        0. expr
                      #   1. _ifstmts_jump (2)
                      #     0. c_stmts
                      #       stmt
                      #         raise_stmt1 (2)
                      #           0. expr
                      #                call (3)
                      #           1. RAISE_VARARGS_1
                      # becomes:
                      # assert2 ::= assert_expr jmp_true LOAD_ASSERT expr RAISE_VARARGS_1 COME_FROM
                      if jump_cond in ("jmp_true", NoneToken):
                          kind = "assert2"
                      else:
                          if jump_cond == "jmp_false":
                              # FIXME: We don't handle this kind of thing yet.
                              return node
                          kind = "assert2not"

                      LOAD_ASSERT = call[0].first_child()
                      if LOAD_ASSERT not in ("LOAD_ASSERT", "LOAD_GLOBAL"):
                          return node
                      if isinstance(call[1], SyntaxTree):
                          expr = call[1][0]
                          assert_expr.transformed_by = "n_ifstmt"
                          node = SyntaxTree(
                              kind,
                              [
                                  assert_expr,
                                  jump_cond,
                                  LOAD_ASSERT,
                                  expr,
                                  RAISE_VARARGS_1,
                              ],
                              transformed_by="n_ifstmt",
                          )
                          pass
                      pass
                  else:
                      # ifstmt
                      #   0. testexpr (2)
                      #      testtrue
                      #        0. expr
                      #   1. _ifstmts_jump (2)
                      #     0. c_stmts
                      #       stmts
                      #         raise_stmt1 (2)
                      #           0. expr
                      #                LOAD_ASSERT
                      #           1. RAISE_VARARGS_1
                      # becomes:
                      # assert ::= assert_expr jmp_true LOAD_ASSERT RAISE_VARARGS_1 COME_FROM
                      if jump_cond in ("jmp_true", NoneToken):
                          if self.is_pypy:
                              kind = "assert0_pypy"
                          else:
                              kind = "assert"
                      else:
                          assert jump_cond == "jmp_false"
                          kind = "assertnot"

                      LOAD_ASSERT = expr[0]
                      node = SyntaxTree(
                          kind,
                          [assert_expr, jump_cond, LOAD_ASSERT, RAISE_VARARGS_1],
                          transformed_by="n_ifstmt",
                      )
                      pass
                  pass
          return node

      n_ifstmtl = n_iflaststmtl = n_ifstmt

      def n_ifelsestmt(self, node, preprocess=False):
          """
          Transformation involving if..else statements.
          For example

            if ...
            else
                if ..

          into:

            if ..
            elif ...
            [else ...]

          where appropriate.

          :param node: an if/else node; ``node[3]`` is taken as its else suite.
          :param preprocess: True when called recursively to handle chains of
              ``if elif elif``; only node kinds are rewritten then, and the
              else suite of *node* is not restructured.
          :return: *node*, possibly with its kind and else suite rewritten.
          """
          else_suite = node[3]

          n = else_suite[0]
          old_stmts = None
          else_suite_index = 1

          len_n = len(n)
          # Sometimes stmt is reduced away and n[0] can be a single reduction
          # like continue -> CONTINUE.
          if (
              len_n == 1
              and isinstance(n[0], SyntaxTree)
              and len(n[0]) == 1
              and n[0] == "stmt"
          ):
              n = n[0][0]
          elif len_n == 0:
              return node

          if n[0].kind in ("lastc_stmt", "lastl_stmt"):
              n = n[0]

          if n[0].kind in (
              "ifstmt",
              "iflaststmt",
              "iflaststmtl",
              "ifelsestmtl",
              "ifelsestmtc",
              "ifpoplaststmtl",
          ):
              n = n[0]
              if n.kind == "ifpoplaststmtl":
                  old_stmts = n[2]
                  else_suite_index = 2
                  pass
          else:
              if (
                  len_n > 1
                  and isinstance(n[0], SyntaxTree)
                  and 1 == len(n[0])
                  and n[0] == "stmt"
                  and n[1].kind == "stmt"
              ):
                  else_suite_stmts = n[0]
              elif len_n == 1:
                  else_suite_stmts = n
              else:
                  return node

              if else_suite_stmts[0].kind in (
                  "ifstmt",
                  "iflaststmt",
                  "ifelsestmt",
                  "ifelsestmtl",
              ):
                  old_stmts = n
                  n = else_suite_stmts[0]
              else:
                  return node

          if n.kind in ("ifstmt", "iflaststmt", "iflaststmtl", "ifpoplaststmtl"):
              node.kind = "ifelifstmt"
              n.kind = "elifstmt"
          elif n.kind in ("ifelsestmtr",):
              node.kind = "ifelifstmt"
              n.kind = "elifelsestmtr"
          elif n.kind in ("ifelsestmt", "ifelsestmtc", "ifelsestmtl"):
              node.kind = "ifelifstmt"
              # Rewrite the nested if/else first so a longer elif chain is
              # recognized; the return value is not needed because n is
              # modified in place.
              self.n_ifelsestmt(n, preprocess=True)
              if n == "ifelifstmt":
                  n.kind = "elifelifstmt"
              elif n.kind in ("ifelsestmt", "ifelsestmtc", "ifelsestmtl"):
                  n.kind = "elifelsestmt"

          if not preprocess:
              if old_stmts:
                  if n.kind == "elifstmt":
                      # Statements after the nested if become the final else.
                      trailing_else = SyntaxTree("stmts", old_stmts[1:])
                      if len(trailing_else):
                          # We use elifelsestmtr because it has 3 nodes
                          elifelse_stmt = SyntaxTree(
                              "elifelsestmtr", [n[0], n[else_suite_index], trailing_else]
                          )
                          node[3] = elifelse_stmt
                      else:
                          elif_stmt = SyntaxTree("elifstmt", [n[0], n[else_suite_index]])
                          node[3] = elif_stmt

                      node.transformed_by = "n_ifelsestmt"
                      pass
                  else:
                      # Other cases for n.kind may happen here
                      pass
                  pass
          return node

      n_ifelsestmtc = n_ifelsestmtl = n_ifelsestmt

      def n_import_from37(self, node):
          """Rewrite a single-alias ``import_from37`` into ``import_as37``.

          The rewrite applies when the alias is stored under the same name as
          the last component of the dotted ``IMPORT_NAME_ATTR`` name. The new
          node is rendered by the ``'%|import %c as %c\\n'`` template.
          """
          importlist37 = node[3]
          if importlist37 != "importlist37":
              return node

          if len(importlist37) == 1:
              alias37 = importlist37[0]
              store = alias37[1]
              assert store == "store"
              alias_name = store[0].attr
              import_name_attr = node[2]
              assert import_name_attr == "IMPORT_NAME_ATTR"
              dotted_names = import_name_attr.attr.split(".")
              if len(dotted_names) > 1 and dotted_names[-1] == alias_name:
                  # Simulate:
                  # Instead of
                  #   import_from37 ::= LOAD_CONST LOAD_CONST IMPORT_NAME_ATTR importlist37 POP_TOP
                  #   import_as37   ::= LOAD_CONST LOAD_CONST importlist37 store POP_TOP
                  #   'import_as37':   ( '%|import %c as %c\n', 2, -2),
                  node = SyntaxTree(
                      "import_as37",
                      [node[0], node[1], import_name_attr, store, node[-1]],
                      transformed_by="n_import_from37",
                  )
                  pass
              pass
          return node

      def n_list_for(self, list_for_node):
          """Remove an extraneous ``get_iter`` wrapped around the iterable in
          the ``for`` of a comprehension."""
          expr = list_for_node[0]
          if expr == "expr" and expr[0] == "get_iter":
              # Remove extraneous get_iter() inside the "for" of a comprehension
              assert expr[0][0] == "expr"
              list_for_node[0] = expr[0][0]
              # A plain string, like every other transform in this class.
              list_for_node.transformed_by = "n_list_for"
          return list_for_node

      def n_negated_testtrue(self, node):
          """Replace a ``negated_testtrue`` node with the expression held
          inside its ``testtrue`` child."""
          assert node[0] == "testtrue"
          test_node = node[0][0]
          test_node.transformed_by = "n_negated_testtrue"
          return test_node

      def n_stmts(self, node):
          """Merge ``x = value`` followed by ``x: annotation`` into one
          ``ann_assign_init`` node.

          Only done for statement lists whose first child is
          ``SETUP_ANNOTATIONS``. The merged-away assignment statement has its
          kind prefixed with "deleted " and is dropped from the list.
          """
          if node.first_child() == "SETUP_ANNOTATIONS":
              prev = node[0][0]
              new_stmts = [node[0]]
              for i, sstmt in enumerate(node[1:]):
                  ann_assign = sstmt[0]
                  if ann_assign == "ann_assign" and prev == "assign":
                      annotate_var = ann_assign[-2]
                      if annotate_var.attr == prev[-1][0].attr:
                          # enumerate() runs over node[1:], so node[i] is the
                          # statement *before* sstmt: the assignment being merged.
                          node[i].kind = "deleted " + node[i].kind
                          del new_stmts[-1]
                          ann_assign_init = SyntaxTree(
                              "ann_assign_init",
                              [ann_assign[0], copy(prev[0]), annotate_var],
                          )
                          if sstmt[0] == "ann_assign":
                              sstmt[0] = ann_assign_init
                          else:
                              sstmt[0][0] = ann_assign_init
                          sstmt[0].transformed_by = "n_stmts"
                          pass
                      pass
                  new_stmts.append(sstmt)
                  prev = ann_assign
                  pass
              node.data = new_stmts
          return node

      def traverse(self, node, is_lambda=False):
          """Run the rewrite pass over *node* and return the resulting tree.

          *is_lambda* is accepted for interface compatibility but not used here.
          """
          node = self.preorder(node)
          return node

      def transform(self, parse_tree: GenericASTTraversal, code) -> GenericASTTraversal:
          """Apply the tree rewrites to a copy of *parse_tree* and return it.

          After the per-node rewrites:
            * a leading bare string expression statement is marked as
              ``string_at_beginning`` so it is not taken for a docstring;
            * single-child ``sstmt`` wrappers are unwrapped, up to and
              including the docstring assignment if one is found;
            * the first docstring assignment (see ``is_docstring``) is moved
              to the front as a ``docstring`` node;
            * a trailing ``RETURN_NONE`` statement is dropped.

          :param parse_tree: tree produced by the parser.
          :param code: code object being decompiled; its ``co_consts`` are
              used to recognize the docstring.
          """
          self.maybe_show_tree(parse_tree)
          # copy() is shallow: child nodes are shared with the caller's tree
          # and the n_* handlers may modify them in place.
          self.ast = copy(parse_tree)
          del parse_tree
          self.ast = self.traverse(self.ast, is_lambda=False)
          n = len(self.ast)

          try:
              # Disambiguate a string (expression) which appears as a "call_stmt" at
              # the beginning of a function versus a docstring. Seems pretty academic,
              # but this is Python.
              call_stmt = self.ast[0][0]
              if is_not_docstring(call_stmt):
                  call_stmt.kind = "string_at_beginning"
                  call_stmt.transformed_by = "transform"
                  pass
          except Exception:
              # Best effort: the tree may not have this shape at all.
              pass

          try:
              for i in range(n):
                  sstmt = self.ast[i]
                  if len(sstmt) == 1 and sstmt == "sstmt":
                      self.ast[i] = self.ast[i][0]

                  if is_docstring(self.ast[i], self.version, code.co_consts):
                      load_const = copy(self.ast[i].first_child())
                      store_name = copy(self.ast[i].last_child())
                      docstring_ast = SyntaxTree("docstring", [load_const, store_name])
                      docstring_ast.transformed_by = "transform"
                      del self.ast[i]
                      self.ast.insert(0, docstring_ast)
                      break

              if self.ast[-1] == RETURN_NONE:
                  self.ast.pop()  # remove last node
                  # todo: if empty, add 'pass'
          except Exception:
              # Best effort, as in the original: an unexpected tree shape
              # leaves the tree as rewritten so far.
              pass

          return self.ast
  446. # Write template_engine
  447. # def template_engine