transform.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585
  1. # Copyright (c) 2019-2020, 2022-2024 by Rocky Bernstein
  2. # This program is free software: you can redistribute it and/or modify
  3. # it under the terms of the GNU General Public License as published by
  4. # the Free Software Foundation, either version 3 of the License, or
  5. # (at your option) any later version.
  6. #
  7. # This program is distributed in the hope that it will be useful,
  8. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. # GNU General Public License for more details.
  11. #
  12. # You should have received a copy of the GNU General Public License
  13. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  14. from copy import copy
  15. from typing import Callable, Optional
  16. from spark_parser import GenericASTTraversal, GenericASTTraversalPruningException
  17. from decompyle3.parsers.treenode import SyntaxTree
  18. from decompyle3.scanners.tok import NoneToken, Token
  19. from decompyle3.semantics.consts import ASSIGN_DOC_STRING, RETURN_NONE
  20. from decompyle3.semantics.helper import find_code_node
  21. from decompyle3.show import maybe_show_tree
  22. # Eventually we won't need STRIPPED_NODES because all semantic
  23. # actions will have been converted to the new form. So here, we will
  24. # do everything by default.
  25. STRIPPED_NODES = (
  26. "_come_froms",
  27. # "and_or",
  28. # "and_or_expr1",
  29. # "and_or_parts",
  30. # "and_part",
  31. # "and1",
  32. # "async_for_loop",
  33. # "async_iter",
  34. # "branch_op",
  35. "bb_start_opt",
  36. "come_froms",
  37. # "comp_if",
  38. # "comp_iter",
  39. # "comp_iter_outer",
  40. # "compare_chained",
  41. # "compare_chained_return",
  42. "compare_chained_middle",
  43. "compare_chained_middle_return",
  44. "compare_chained_right",
  45. "compare_chained_right_return",
  46. # "dict_comp_func",
  47. "ending_return",
  48. # "expr_pjif",
  49. # "expr_pjit",
  50. # "for_jump_pop_iff",
  51. # "for_jump_unconditional",
  52. # "for_loop",
  53. "forelsestmt38",
  54. # "genexpr_func",
  55. "genexpr_func_async",
  56. "if_exp_compare38",
  57. # "jifop",
  58. # "jitop",
  59. # "or",
  60. # "or_parts_pjit",
  61. # "pjump_iff_loop",
  62. # "return_expr",
  63. # "set_comp_func",
  64. "tryfinallystmt",
  65. "with_as",
  66. )
  67. def is_docstring(node, co_consts) -> bool:
  68. # try:
  69. # return node.kind == "assign" and node[1][0].pattr == "__doc__"
  70. # except:
  71. # return False
  72. return node == ASSIGN_DOC_STRING(co_consts[0], "LOAD_STR")
  73. def is_not_docstring(call_stmt_node) -> bool:
  74. try:
  75. return (
  76. call_stmt_node == "call_stmt"
  77. and call_stmt_node[0][0] == "LOAD_STR"
  78. and call_stmt_node[1] == "POP_TOP"
  79. )
  80. except Exception:
  81. return False
  82. class TreeTransform(GenericASTTraversal, object):
  83. def __init__(
  84. self,
  85. version: tuple,
  86. str_with_template: Callable,
  87. show_ast: Optional[dict] = None,
  88. ):
  89. self.showast = show_ast
  90. self.version = version
  91. self.str_with_template_for_later = str_with_template
  92. self.str_with_template = None
  93. return
  94. def maybe_show_tree(self, tree, phase: str, print_fn: Callable):
  95. if phase == "before":
  96. phase_name = "parse tree"
  97. else:
  98. phase_name = "transformed abstract tree"
  99. self.str_with_template = self.str_with_template_for_later
  100. if isinstance(self.showast, dict) and self.showast.get(phase, False):
  101. print_fn(f"""\n# ---- {phase_name}:\n """)
  102. maybe_show_tree(self, tree)
  103. def preorder(self, node=None):
  104. """Walk the tree in roughly 'preorder' (a bit of a lie explained below).
  105. For each node with typestring name *name* if the
  106. node has a method called n_*name*, call that before walking
  107. children.
  108. In typical use a node with children can call "preorder" in any
  109. order it wants which may skip children or order then in ways
  110. other than first to last. In fact, this this happens. So in
  111. this sense this function not strictly preorder.
  112. """
  113. if node is None:
  114. node = self.ast
  115. try:
  116. name = "n_" + self.typestring(node)
  117. if hasattr(self, name):
  118. func = getattr(self, name)
  119. node = func(node)
  120. except GenericASTTraversalPruningException:
  121. return
  122. if node.kind in STRIPPED_NODES:
  123. return self.strip_pseudo_ops(node)
  124. for i, kid in enumerate(node):
  125. node[i] = self.preorder(kid)
  126. return node
  127. def n_await_expr(self, node):
  128. """Here we check for await(await)"""
  129. expr = node[0]
  130. assert expr == "expr"
  131. if expr[0] == "await_expr":
  132. expr[0].transformed_by = "n_await_expr"
  133. return expr[0]
  134. return node
  135. def n_mkfunc(self, node):
  136. """If the function has a docstring (this is found in the code
  137. constants), pull that out and make it part of the syntax
  138. tree. When generating the source string that AST node rather
  139. than the code field is seen and used.
  140. """
  141. code = find_code_node(node, -3).attr
  142. mkfunc_pattr = node[-1].pattr
  143. if isinstance(mkfunc_pattr, tuple):
  144. assert isinstance(mkfunc_pattr, tuple)
  145. assert len(mkfunc_pattr) == 4 and isinstance(mkfunc_pattr, int)
  146. if len(code.co_consts) > 0 and isinstance(code.co_consts[0], str):
  147. docstring_node = SyntaxTree(
  148. "docstring",
  149. [Token("LOAD_STR", has_arg=True, pattr=code.co_consts[0])],
  150. transformed_by="n_mkfunc",
  151. )
  152. node = SyntaxTree(
  153. "mkfunc",
  154. node[:-1] + [docstring_node, node[-1]],
  155. transformed_by="n_mkfunc",
  156. )
  157. return node
  158. def n_ifstmt(self, node):
  159. """Here we check if we can turn an `ifstmt` or 'iflaststmtc` into
  160. some kind of `assert` statement.
  161. Also:
  162. if or_in_ifexp ifstmts_jump becomes:
  163. if "not_or ifstmts_jump
  164. """
  165. testexpr = node[0]
  166. if testexpr not in ("testexpr", "testexprc"):
  167. return node
  168. if node.kind in ("ifstmt", "ifstmtc"):
  169. stmts = None
  170. ifstmts_jump = node[1]
  171. if ifstmts_jump == "ifstmts_jump":
  172. testtrue = copy(testexpr[0])
  173. if testtrue == "testtrue" and testtrue[0] == "or_in_ifexp":
  174. testfalse = copy(testtrue)
  175. testfalse.kind = "testfalse"
  176. testfalse[0].kind = "or_not"
  177. node = SyntaxTree(
  178. "if_or_not_stmt",
  179. [testfalse, ifstmts_jump],
  180. transformed_by="n_ifstmt",
  181. )
  182. return node
  183. if ifstmts_jump == "ifstmts_jumpc" and ifstmts_jump[0] == "ifstmts_jump":
  184. ifstmts_jump = ifstmts_jump[0]
  185. elif ifstmts_jump in ("stmts",):
  186. stmts = node[1]
  187. elif ifstmts_jump not in ("ifstmts_jump", "ifstmts_jumpc"):
  188. return node
  189. if stmts is None:
  190. stmts = ifstmts_jump[0]
  191. else:
  192. # iflaststmt{c,} works this way
  193. stmts = node[1]
  194. if stmts in ("c_stmts", "stmts", "stmts_opt") and len(stmts) == 1:
  195. raise_stmt = stmts[0]
  196. if raise_stmt != "raise_stmt1" and len(raise_stmt) > 0:
  197. raise_stmt = raise_stmt[0]
  198. testtrue_or_false = testexpr[0]
  199. if testtrue_or_false == "testexpr":
  200. testtrue_or_false = testtrue_or_false[0]
  201. if (
  202. raise_stmt == "raise_stmt1"
  203. and 1 <= len(testtrue_or_false) <= 2
  204. and raise_stmt.first_child().pattr == "AssertionError"
  205. ):
  206. if testtrue_or_false in ("testtrue", "testtruec"):
  207. # Skip over the testtrue because because it would
  208. # produce a "not" and we don't want that here.
  209. assert_expr = testtrue_or_false[0]
  210. jump_cond = NoneToken
  211. else:
  212. assert testtrue_or_false in (
  213. "testfalse",
  214. "testfalsec",
  215. ), testtrue_or_false
  216. assert_expr = testtrue_or_false[0]
  217. if assert_expr in ("and_not", "nand", "not_or", "and"):
  218. # FIXME: come back to stuff like this
  219. return node
  220. if testtrue_or_false[0] == "expr_pjif":
  221. jump_cond = testtrue_or_false[0][1]
  222. else:
  223. jump_cond = testtrue_or_false[1]
  224. assert_expr.kind = "assert_expr"
  225. pass
  226. expr = raise_stmt[0]
  227. RAISE_VARARGS_1 = raise_stmt[1]
  228. call = expr[0]
  229. if call == "call":
  230. # ifstmt
  231. # 0. testexpr
  232. # testtrue (2)
  233. # 0. expr
  234. # 1. _ifstmts_jump (2)
  235. # 0. c_stmts
  236. # stmt
  237. # raise_stmt1 (2)
  238. # 0. expr
  239. # call (3)
  240. # 1. RAISE_VARARGS_1
  241. # becomes:
  242. # assert2 ::= assert_expr POP_JUMP_IF_TRUE LOAD_ASSERT expr
  243. # RAISE_VARARGS_1 COME_FROM
  244. if jump_cond in ("POP_JUMP_IF_TRUE", NoneToken):
  245. kind = "assert2"
  246. else:
  247. if jump_cond == "POP_JUMP_IF_FALSE":
  248. # FIXME: We don't handle this kind of thing yet.
  249. return node
  250. kind = "assert2not"
  251. LOAD_ASSERT = call[0].first_child()
  252. if LOAD_ASSERT not in ("LOAD_ASSERT", "LOAD_GLOBAL"):
  253. return node
  254. if isinstance(call[1], SyntaxTree):
  255. expr = call[1][0]
  256. assert_expr.transformed_by = "n_ifstmt"
  257. node = SyntaxTree(
  258. kind,
  259. [
  260. assert_expr,
  261. jump_cond,
  262. LOAD_ASSERT,
  263. expr,
  264. RAISE_VARARGS_1,
  265. ],
  266. transformed_by="n_ifstmt",
  267. )
  268. pass
  269. pass
  270. else:
  271. # ifstmt
  272. # 0. testexpr (2)
  273. # testtrue
  274. # 0. expr
  275. # 1. _ifstmts_jump (2)
  276. # 0. c_stmts
  277. # stmts
  278. # raise_stmt1 (2)
  279. # 0. expr
  280. # LOAD_ASSERT
  281. # 1. RAISE_VARARGS_1
  282. # becomes:
  283. # assert ::= assert_expr POP_JUMP_IF_TRUE LOAD_ASSERT
  284. # RAISE_VARARGS_1 COME_FROM
  285. assert_expr.transformed_by = "n_ifstmt"
  286. if jump_cond in (
  287. "POP_JUMP_IF_TRUE",
  288. "POP_JUMP_IF_TRUE_LOOP",
  289. NoneToken,
  290. ):
  291. kind = "assert"
  292. else:
  293. assert jump_cond.kind.startswith("POP_JUMP_IF_")
  294. kind = "assertnot"
  295. LOAD_ASSERT = expr[0]
  296. node = SyntaxTree(
  297. kind,
  298. [assert_expr, jump_cond, LOAD_ASSERT, RAISE_VARARGS_1],
  299. transformed_by="n_ifstmt",
  300. )
  301. pass
  302. pass
  303. return node
  304. n_ifstmtc = n_iflaststmtc = n_iflaststmt = n_ifstmt
  305. # preprocess is used for handling chains of
  306. # if elif elif
  307. def n_ifelsestmt(self, node, preprocess=False):
  308. """
  309. Transformation involving if..else statements.
  310. For example
  311. if ...
  312. else
  313. if ..
  314. into:
  315. if ..
  316. elif ...
  317. [else ...]
  318. where appropriate.
  319. """
  320. else_suite = node[3]
  321. n = else_suite[0]
  322. old_stmts = None
  323. else_suite_index = 1
  324. if len(n) and n[0] == "suite_stmts":
  325. n = n[0]
  326. len_n = len(n)
  327. if len_n == 1 == len(n[0]) and n[0] in ("c_stmt", "stmt", "stmts"):
  328. n = n[0][0]
  329. elif len_n == 0:
  330. return node
  331. elif n[0].kind in ("lastc_stmt",):
  332. n = n[0]
  333. if n[0].kind in (
  334. "ifstmt",
  335. "iflaststmt",
  336. "iflaststmtc",
  337. "ifelsestmtc",
  338. "ifpoplaststmtc",
  339. ):
  340. n = n[0]
  341. if n.kind == "ifpoplaststmtc":
  342. old_stmts = n[2]
  343. else_suite_index = 2
  344. pass
  345. pass
  346. else:
  347. while n[0].kind in ("_stmts", "c_stmts", "stmts"):
  348. n = n[0]
  349. len_n = len(n)
  350. if (
  351. len_n > 1
  352. and isinstance(n[0], SyntaxTree)
  353. and 1 == len(n[0])
  354. and n[0] in ("c_stmt", "stmt")
  355. and n[1].kind in ("c_stmt", "stmt")
  356. ):
  357. else_suite_stmts = n[0]
  358. elif len_n == 1:
  359. else_suite_stmts = n
  360. else:
  361. return node
  362. if else_suite_stmts[0].kind in (
  363. "ifstmt",
  364. "iflaststmt",
  365. "ifelsestmt",
  366. "ifelsestmtl",
  367. "ifelsestmtc",
  368. ):
  369. old_stmts = n
  370. n = else_suite_stmts[0]
  371. else:
  372. return node
  373. if n.kind == "last_stmt":
  374. n = n[0]
  375. if n.kind in ("ifstmt", "iflaststmt", "iflaststmtc", "ifpoplaststmtc"):
  376. node.kind = "ifelifstmt"
  377. n.kind = "elifstmt"
  378. elif n.kind in ("ifelsestmtr",):
  379. node.kind = "ifelifstmt"
  380. n.kind = "elifelsestmtr"
  381. elif n.kind in ("ifelsestmt", "ifelsestmtc", "ifelsestmtc"):
  382. node.kind = "ifelifstmt"
  383. self.n_ifelsestmt(n, preprocess=True)
  384. if n == "ifelifstmt":
  385. n.kind = "elifelifstmt"
  386. elif n.kind in ("ifelsestmt", "ifelsestmtc"):
  387. n.kind = "elifelsestmt"
  388. if not preprocess:
  389. if old_stmts:
  390. if n.kind == "elifstmt":
  391. trailing_else = SyntaxTree("stmts", old_stmts[1:])
  392. if len(trailing_else):
  393. # We use elifelsestmtr because it has 3 nodes
  394. elifelse_stmt = SyntaxTree(
  395. "elifelsestmtr", [n[0], n[else_suite_index], trailing_else]
  396. )
  397. node[3] = elifelse_stmt
  398. else:
  399. elif_stmt = SyntaxTree("elifstmt", [n[0], n[else_suite_index]])
  400. node[3] = elif_stmt
  401. node.transformed_by = "n_ifelsestmt"
  402. pass
  403. else:
  404. # Other cases for n.kind may happen here
  405. pass
  406. pass
  407. return node
  408. n_ifelsestmtc = n_ifelsestmt
  409. def n_import_from37(self, node):
  410. importlist37 = node[3]
  411. if len(importlist37) == 1 and importlist37 == "importlist37":
  412. alias37 = importlist37[0]
  413. store = alias37[1]
  414. assert store == "store"
  415. alias_name = store[0].attr
  416. import_name_attr = node[2]
  417. assert import_name_attr == "IMPORT_NAME_ATTR"
  418. dotted_names = import_name_attr.attr.split(".")
  419. if len(dotted_names) > 1 and dotted_names[-1] == alias_name:
  420. # Simulate:
  421. # Instead of
  422. # import_from37 ::= LOAD_CONST LOAD_CONST IMPORT_NAME_ATTR importlist37 POP_TOP
  423. # import_as37 ::= LOAD_CONST LOAD_CONST importlist37 store POP_TOP
  424. # 'import_as37': ( '%|import %c as %c\n', 2, -2),
  425. node = SyntaxTree(
  426. "import_as37",
  427. [node[0], node[1], import_name_attr, store, node[-1]],
  428. transformed_by="n_import_from37",
  429. )
  430. pass
  431. pass
  432. return node
  433. def n_list_for(self, list_for_node):
  434. expr = list_for_node[0]
  435. if expr == "expr" and expr[0] == "get_iter":
  436. # Remove extraneous get_iter() inside the "for" of a comprehension
  437. assert expr[0][0] == "expr"
  438. list_for_node[0] = expr[0][0]
  439. list_for_node.transformed_by = ("n_list_for",)
  440. return list_for_node
  441. def n_negated_testtrue(self, node):
  442. assert node[0] == "testtrue"
  443. test_node = node[0][0]
  444. test_node.transformed_by = "n_negated_testtrue"
  445. return test_node
  446. def n_stmts(self, node):
  447. if node.first_child() == "SETUP_ANNOTATIONS":
  448. prev = node[0]
  449. new_stmts = [node[0]]
  450. for i, sstmt in enumerate(node[1:]):
  451. ann_assign = sstmt
  452. if ann_assign == "ann_assign" and prev == "assign":
  453. annotate_var = ann_assign[-2]
  454. if annotate_var.attr == prev[-1][0].attr:
  455. node[i].kind = "deleted " + node[i].kind
  456. del new_stmts[-1]
  457. sstmt = SyntaxTree(
  458. "ann_assign_init",
  459. [ann_assign[0], prev[0], annotate_var],
  460. transformed_by="n_stmts",
  461. )
  462. pass
  463. pass
  464. new_stmts.append(sstmt)
  465. prev = ann_assign
  466. pass
  467. node.data = new_stmts
  468. return node
  469. def traverse(self, node):
  470. node = self.preorder(node)
  471. return node
  472. def transform(
  473. self, parse_tree: GenericASTTraversal, code, print_fn: Callable
  474. ) -> GenericASTTraversal:
  475. self.maybe_show_tree(parse_tree, "before", print_fn)
  476. self.ast = copy(parse_tree)
  477. del parse_tree
  478. self.ast = self.traverse(self.ast)
  479. n = len(self.ast)
  480. try:
  481. # Disambiguate a string (expression) which appears as a "call_stmt" at
  482. # the beginning of a function versus a docstring. Seems pretty academic,
  483. # but this is Python.
  484. call_stmt = self.ast[0][0]
  485. if is_not_docstring(call_stmt):
  486. call_stmt.kind = "string_at_beginning"
  487. call_stmt.transformed_by = "transform"
  488. pass
  489. except Exception:
  490. pass
  491. try:
  492. for i in range(n):
  493. if is_docstring(self.ast[i], code.co_consts):
  494. load_const = copy(self.ast[i].first_child())
  495. store_name = copy(self.ast[i].last_child())
  496. docstring_ast = SyntaxTree(
  497. "docstring",
  498. [load_const, store_name],
  499. transformed_by="transform",
  500. )
  501. del self.ast[i]
  502. self.ast.insert(0, docstring_ast)
  503. break
  504. if self.ast[-1] == RETURN_NONE:
  505. self.ast.pop() # remove last node
  506. # todo: if empty, add 'pass'
  507. except Exception:
  508. pass
  509. self.maybe_show_tree(self.ast, "after", print_fn)
  510. return self.ast
  511. # Write template_engine
  512. # def template_engine
  513. def strip_pseudo_ops(self, node: SyntaxTree) -> SyntaxTree:
  514. new_node = SyntaxTree(node.kind)
  515. for i, kid in enumerate(node):
  516. if hasattr(kid, "optype") and kid.optype == "pseudo":
  517. continue
  518. new_kid = self.preorder(kid)
  519. new_node.data.append(new_kid)
  520. del node
  521. return new_node