helper.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263
  1. import sys
  2. from xdis import iscode
  3. from decompyle3.parsers.treenode import SyntaxTree
  4. minint = -sys.maxsize - 1
  5. maxint = sys.maxsize
  6. read_write_global_ops = frozenset(("STORE_GLOBAL", "DELETE_GLOBAL", "LOAD_GLOBAL"))
  7. read_global_ops = frozenset(("STORE_GLOBAL", "DELETE_GLOBAL"))
  8. # NOTE: we also need to check that the variable name is a free variable, not a cell variable.
  9. nonglobal_ops = frozenset(("STORE_DEREF", "DELETE_DEREF"))
  10. def escape_string(s: str, quotes=('"', "'", '"""', "'''")):
  11. quote = None
  12. for q in quotes:
  13. if s.find(q) == -1:
  14. quote = q
  15. break
  16. pass
  17. if quote is None:
  18. quote = '"""'
  19. s = s.replace('"""', '\\"""')
  20. for (orig, replace) in (("\t", "\\t"), ("\n", "\\n"), ("\r", "\\r")):
  21. s = s.replace(orig, replace)
  22. return "%s%s%s" % (quote, s, quote)
  23. # FIXME: This and find_globals could be parameterized with one of the
  24. # above global ops.
  25. def find_all_globals(node, globs):
  26. """Search Syntax Tree node to find variable names that are global."""
  27. for n in node:
  28. if isinstance(n, SyntaxTree):
  29. globs = find_all_globals(n, globs)
  30. elif n.kind in read_write_global_ops:
  31. globs.add(n.pattr)
  32. return globs
  33. def find_code_node(node, start: int):
  34. for i in range(-start, len(node) + 1):
  35. if node[-i].kind == "LOAD_CODE":
  36. code_node = node[-i]
  37. assert iscode(code_node.attr)
  38. return code_node
  39. pass
  40. assert False, "did not find code node starting at %d in %s" % (start, node)
  41. def find_globals_and_nonlocals(node, globs, nonlocals, code, version):
  42. """search a node of parse tree to find variable names that need a
  43. either 'global' or 'nonlocal' statements added."""
  44. for n in node:
  45. if isinstance(n, SyntaxTree):
  46. globs, nonlocals = find_globals_and_nonlocals(
  47. n, globs, nonlocals, code, version
  48. )
  49. elif n.kind in read_global_ops:
  50. globs.add(n.pattr)
  51. elif (
  52. version >= (3, 0)
  53. and n.kind in nonglobal_ops
  54. and n.pattr in code.co_freevars
  55. and n.pattr != code.co_name
  56. and code.co_name != "<lambda>"
  57. ):
  58. nonlocals.add(n.pattr)
  59. return globs, nonlocals
  60. # def find_globals(node, globs, global_ops=mkfunc_globals):
  61. # """Find globals in this statement."""
  62. # for n in node:
  63. # # print("XXX", n.kind, global_ops)
  64. # if isinstance(n, SyntaxTree):
  65. # # FIXME: do I need a caser for n.kind="mkfunc"?
  66. # if n.kind in ("if_exp_lambda", "return_expr_lambda"):
  67. # globs = find_globals(n, globs, lambda_body_globals)
  68. # else:
  69. # globs = find_globals(n, globs, global_ops)
  70. # elif n.kind in frozenset(global_ops):
  71. # globs.add(n.pattr)
  72. # return globs
  73. def find_none(node):
  74. for n in node:
  75. if isinstance(n, SyntaxTree):
  76. if n not in ("return_stmt", "return_if_stmt"):
  77. if find_none(n):
  78. return True
  79. elif n.kind == "LOAD_CONST" and n.pattr is None:
  80. return True
  81. return False
  82. def flatten_list(node):
  83. """
  84. List of expressions may be nested in groups of 32 and 1024
  85. items. flatten that out and return the list
  86. """
  87. flat_elems = []
  88. for elem in node:
  89. if elem == "expr1024":
  90. for subelem in elem:
  91. assert subelem == "expr32"
  92. for subsubelem in subelem:
  93. flat_elems.append(subsubelem)
  94. elif elem == "expr32":
  95. for subelem in elem:
  96. assert subelem == "expr"
  97. flat_elems.append(subelem)
  98. else:
  99. flat_elems.append(elem)
  100. pass
  101. pass
  102. return flat_elems
  103. def is_lambda_mode(compile_mode: str) -> bool:
  104. return compile_mode in ("dictcomp", "genexpr", "lambda", "listcomp", "setcomp")
  105. def strip_quotes(s: str) -> str:
  106. if s.startswith("'''") and s.endswith("'''"):
  107. s = s[3:-3]
  108. elif s.startswith('"""') and s.endswith('"""'):
  109. s = s[3:-3]
  110. elif s.startswith("'") and s.endswith("'"):
  111. s = s[1:-1]
  112. elif s.startswith('"') and s.endswith('"'):
  113. s = s[1:-1]
  114. pass
  115. return s
  116. def print_docstring(self, indent, docstring):
  117. quote = '"""'
  118. if docstring.find(quote) >= 0:
  119. if docstring.find("'''") == -1:
  120. quote = "'''"
  121. self.write(indent)
  122. docstring = repr(docstring.expandtabs())[1:-1]
  123. for (orig, replace) in (
  124. ("\\\\", "\t"),
  125. ("\\r\\n", "\n"),
  126. ("\\n", "\n"),
  127. ("\\r", "\n"),
  128. ('\\"', '"'),
  129. ("\\'", "'"),
  130. ):
  131. docstring = docstring.replace(orig, replace)
  132. # Do a raw string if there are backslashes but no other escaped characters:
  133. # also check some edge cases
  134. if (
  135. "\t" in docstring
  136. and "\\" not in docstring
  137. and len(docstring) >= 2
  138. and docstring[-1] != "\t"
  139. and (docstring[-1] != '"' or docstring[-2] == "\t")
  140. ):
  141. self.write("r") # raw string
  142. # Restore backslashes unescaped since raw
  143. docstring = docstring.replace("\t", "\\")
  144. else:
  145. # Escape the last character if it is the same as the
  146. # triple quote character.
  147. quote1 = quote[-1]
  148. if len(docstring) and docstring[-1] == quote1:
  149. docstring = docstring[:-1] + "\\" + quote1
  150. # Escape triple quote when needed
  151. if quote == '"""':
  152. replace_str = '\\"""'
  153. else:
  154. assert quote == "'''"
  155. replace_str = "\\'''"
  156. docstring = docstring.replace(quote, replace_str)
  157. docstring = docstring.replace("\t", "\\\\")
  158. lines = docstring.split("\n")
  159. self.write(quote)
  160. if len(lines) == 0:
  161. self.println(quote)
  162. elif len(lines) == 1:
  163. self.println(lines[0], quote)
  164. else:
  165. self.println(lines[0])
  166. for line in lines[1:-1]:
  167. if line:
  168. self.println(line)
  169. else:
  170. self.println("\n\n")
  171. pass
  172. pass
  173. self.println(lines[-1], quote)
  174. return True
  175. # if __name__ == '__main__':
  176. # from io import StringIO
  177. # class PrintFake():
  178. # def __init__(self):
  179. # self.pending_newlines = 0
  180. # self.f = StringIO()
  181. # def write(self, *data):
  182. # if (len(data) == 0) or (len(data) == 1 and data[0] == ''):
  183. # return
  184. # out = ''.join((str(j) for j in data))
  185. # n = 0
  186. # for i in out:
  187. # if i == '\n':
  188. # n += 1
  189. # if n == len(out):
  190. # self.pending_newlines = max(self.pending_newlines, n)
  191. # return
  192. # elif n:
  193. # self.pending_newlines = max(self.pending_newlines, n)
  194. # out = out[n:]
  195. # break
  196. # else:
  197. # break
  198. # if self.pending_newlines > 0:
  199. # self.f.write('\n'*self.pending_newlines)
  200. # self.pending_newlines = 0
  201. # for i in out[::-1]:
  202. # if i == '\n':
  203. # self.pending_newlines += 1
  204. # else:
  205. # break
  206. # if self.pending_newlines:
  207. # out = out[:-self.pending_newlines]
  208. # self.f.write(out)
  209. # def println(self, *data):
  210. # if data and not(len(data) == 1 and data[0] ==''):
  211. # self.write(*data)
  212. # self.pending_newlines = max(self.pending_newlines, 1)
  213. # return
  214. # pass
  215. # for doc in (
  216. # "Now is the time",
  217. # r'''func placeholder - with ("""\nstring\n""")''',
  218. # r'''func placeholder - ' and with ("""\nstring\n""")''',
  219. # r"""func placeholder - ' and with ('''\nstring\n''') and \"\"\"\nstring\n\"\"\" """
  220. # ):
  221. # o = PrintFake()
  222. # print_docstring(o, ' ', doc)
  223. # print(o.f.getvalue())