helper.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. # Copyright (c) 2022-2023 Rocky Bernstein
  2. #
  3. # This program is free software: you can redistribute it and/or modify
  4. # it under the terms of the GNU General Public License as published by
  5. # the Free Software Foundation, either version 3 of the License, or
  6. # (at your option) any later version.
  7. #
  8. # This program is distributed in the hope that it will be useful,
  9. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  10. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  11. # GNU General Public License for more details.
  12. #
  13. # You should have received a copy of the GNU General Public License
  14. # along with this program. If not, see <http://www.gnu.org/licenses/>.
  15. import sys
  16. from xdis import iscode
  17. from uncompyle6.parsers.treenode import SyntaxTree
  18. minint = -sys.maxsize-1
  19. maxint = sys.maxsize
  20. read_write_global_ops = frozenset(('STORE_GLOBAL', 'DELETE_GLOBAL', 'LOAD_GLOBAL'))
  21. read_global_ops = frozenset(('STORE_GLOBAL', 'DELETE_GLOBAL'))
  22. # NOTE: we also need to check that the variable name is a free variable, not a cell variable.
  23. nonglobal_ops = frozenset(('STORE_DEREF', 'DELETE_DEREF'))
  24. def escape_string(s, quotes=('"', "'", '"""', "'''")):
  25. quote = None
  26. for q in quotes:
  27. if s.find(q) == -1:
  28. quote = q
  29. break
  30. pass
  31. if quote is None:
  32. quote = '"""'
  33. s = s.replace('"""', '\\"""')
  34. for (orig, replace) in (('\t', '\\t'),
  35. ('\n', '\\n'),
  36. ('\r', '\\r')):
  37. s = s.replace(orig, replace)
  38. return "%s%s%s" % (quote, s, quote)
  39. # FIXME: this and find_globals could be parameterized with one of the
  40. # above global ops
  41. def find_all_globals(node, globs):
  42. """Search Syntax Tree node to find variable names that are global."""
  43. for n in node:
  44. if isinstance(n, SyntaxTree):
  45. globs = find_all_globals(n, globs)
  46. elif n.kind in read_write_global_ops:
  47. globs.add(n.pattr)
  48. return globs
  49. # def find_globals(node, globs, global_ops=mkfunc_globals):
  50. # """Find globals in this statement."""
  51. # for n in node:
  52. # # print("XXX", n.kind, global_ops)
  53. # if isinstance(n, SyntaxTree):
  54. # # FIXME: do I need a case for n.kind="mkfunc"?
  55. # if n.kind in ("if_exp_lambda", "return_expr_lambda"):
  56. # globs = find_globals(n, globs, lambda_body_globals)
  57. # else:
  58. # globs = find_globals(n, globs, global_ops)
  59. # elif n.kind in frozenset(global_ops):
  60. # globs.add(n.pattr)
  61. # return globs
  62. def find_code_node(node, start):
  63. for i in range(-start, len(node) + 1):
  64. if node[-i].kind == "LOAD_CODE":
  65. code_node = node[-i]
  66. assert iscode(code_node.attr)
  67. return code_node
  68. pass
  69. assert False, "did not find code node starting at %d in %s" % (start, node)
  70. def find_globals_and_nonlocals(node, globs, nonlocals, code, version):
  71. """search a node of parse tree to find variable names that need a
  72. either 'global' or 'nonlocal' statements added."""
  73. for n in node:
  74. if isinstance(n, SyntaxTree):
  75. globs, nonlocals = find_globals_and_nonlocals(n, globs, nonlocals,
  76. code, version)
  77. elif n.kind in read_global_ops:
  78. globs.add(n.pattr)
  79. elif (version >= (3, 0)
  80. and n.kind in nonglobal_ops
  81. and n.pattr in code.co_freevars
  82. and n.pattr != code.co_name
  83. and code.co_name != '<lambda>'):
  84. nonlocals.add(n.pattr)
  85. return globs, nonlocals
  86. def find_none(node):
  87. for n in node:
  88. if isinstance(n, SyntaxTree):
  89. if n not in ('return_stmt', 'return_if_stmt'):
  90. if find_none(n):
  91. return True
  92. elif n.kind == 'LOAD_CONST' and n.pattr is None:
  93. return True
  94. return False
  95. def flatten_list(node):
  96. """
  97. List of expressions may be nested in groups of 32 and 1024
  98. items. flatten that out and return the list
  99. """
  100. flat_elems = []
  101. for elem in node:
  102. if elem == 'expr1024':
  103. for subelem in elem:
  104. assert subelem == 'expr32'
  105. for subsubelem in subelem:
  106. flat_elems.append(subsubelem)
  107. elif elem == 'expr32':
  108. for subelem in elem:
  109. assert subelem == 'expr'
  110. flat_elems.append(subelem)
  111. else:
  112. flat_elems.append(elem)
  113. pass
  114. pass
  115. return flat_elems
  116. # Note: this is only used in Python > 3.0
  117. # Should move this somewhere more specific?
  118. def gen_function_parens_adjust(mapping_key, node):
  119. """If we can avoid the outer parenthesis
  120. of a generator function, set the node key to
  121. 'call_generator' and the caller will do the default
  122. action on that. Otherwise we do nothing.
  123. """
  124. if mapping_key.kind != 'CALL_FUNCTION_1':
  125. return
  126. args_node = node[-2]
  127. if args_node == 'pos_arg':
  128. assert args_node[0] == 'expr'
  129. n = args_node[0][0]
  130. if n == 'generator_exp':
  131. node.kind = 'call_generator'
  132. pass
  133. return
  134. def is_lambda_mode(compile_mode: str) -> bool:
  135. return compile_mode in ("dictcomp", "genexpr", "lambda", "listcomp", "setcomp")
def print_docstring(self, indent, docstring):
    """Emit ``docstring`` as a triple-quoted string literal.

    Output goes through ``self.write`` and ``self.println``.

    :param self: object providing ``write`` and ``println`` methods.
    :param indent: text written before the literal.
    :param docstring: the docstring text; bytes are decoded as UTF-8 with
        undecodable bytes rendered as backslash escapes.
    :return: always True.
    """
    if isinstance(docstring, bytes):
        docstring = docstring.decode("utf8", errors="backslashreplace")

    # Prefer """ as the delimiter; switch to ''' only when the text contains
    # """ but does not contain '''.
    quote = '"""'
    if docstring.find(quote) >= 0:
        if docstring.find("'''") == -1:
            quote = "'''"

    self.write(indent)
    # repr() yields an escaped form; [1:-1] drops the quote characters that
    # repr() puts around it.  Tabs are expanded first, so the result holds no
    # literal tab character.
    docstring = repr(docstring.expandtabs())[1:-1]

    # Order matters here: escaped backslashes are first swapped for a tab
    # character, used as a temporary placeholder, so that the following
    # replacements only see genuine \r, \n and quote escapes.  Those are then
    # turned back into real newlines and bare quotes.
    for (orig, replace) in (('\\\\', '\t'),
                            ('\\r\\n', '\n'),
                            ('\\n', '\n'),
                            ('\\r', '\n'),
                            ('\\"', '"'),
                            ("\\'", "'")):
        docstring = docstring.replace(orig, replace)

    # Do a raw string if there are backslashes but no other escaped characters:
    # also check some edge cases
    # (the text must be at least two characters long, must not end in a
    # backslash placeholder, and may end in '"' only when that quote is
    # directly preceded by a backslash placeholder).
    if ('\t' in docstring
        and '\\' not in docstring
        and len(docstring) >= 2
        and docstring[-1] != '\t'
        and (docstring[-1] != '"'
             or docstring[-2] == '\t')):
        self.write('r') # raw string
        # Restore backslashes unescaped since raw
        docstring = docstring.replace('\t', '\\')
    else:
        # Escape the last character if it is the same as the
        # triple quote character.
        quote1 = quote[-1]
        if len(docstring) and docstring[-1] == quote1:
            docstring = docstring[:-1] + '\\' + quote1

        # Escape triple quote when needed
        if quote == '"""':
            replace_str = '\\"""'
        else:
            assert quote == "'''"
            replace_str = "\\'''"

        docstring = docstring.replace(quote, replace_str)
        # Turn each backslash placeholder back into an escaped backslash.
        docstring = docstring.replace('\t', '\\\\')

    lines = docstring.split('\n')

    self.write(quote)
    # NOTE(review): str.split() with an explicit separator always returns at
    # least one element, so the first branch looks unreachable.
    if len(lines) == 0:
        self.println(quote)
    elif len(lines) == 1:
        # Single-line docstring: text and closing quote on the same line.
        self.println(lines[0], quote)
    else:
        self.println(lines[0])
        for line in lines[1:-1]:
            if line:
                self.println( line )
            else:
                # An empty interior line is passed to println as two newlines.
                self.println( "\n\n" )
                pass
            pass
        # The closing quote follows the final line's text.
        self.println(lines[-1], quote)
    return True
  194. def strip_quotes(s):
  195. if s.startswith("'''") and s.endswith("'''"):
  196. s = s[3:-3]
  197. elif s.startswith('"""') and s.endswith('"""'):
  198. s = s[3:-3]
  199. elif s.startswith("'") and s.endswith("'"):
  200. s = s[1:-1]
  201. elif s.startswith('"') and s.endswith('"'):
  202. s = s[1:-1]
  203. pass
  204. return s
  205. # if __name__ == '__main__':
  206. # from io import StringIO
  207. # class PrintFake():
  208. # def __init__(self):
  209. # self.pending_newlines = 0
  210. # self.f = StringIO()
  211. # def write(self, *data):
  212. # if (len(data) == 0) or (len(data) == 1 and data[0] == ''):
  213. # return
  214. # out = ''.join((str(j) for j in data))
  215. # n = 0
  216. # for i in out:
  217. # if i == '\n':
  218. # n += 1
  219. # if n == len(out):
  220. # self.pending_newlines = max(self.pending_newlines, n)
  221. # return
  222. # elif n:
  223. # self.pending_newlines = max(self.pending_newlines, n)
  224. # out = out[n:]
  225. # break
  226. # else:
  227. # break
  228. # if self.pending_newlines > 0:
  229. # self.f.write('\n'*self.pending_newlines)
  230. # self.pending_newlines = 0
  231. # for i in out[::-1]:
  232. # if i == '\n':
  233. # self.pending_newlines += 1
  234. # else:
  235. # break
  236. # if self.pending_newlines:
  237. # out = out[:-self.pending_newlines]
  238. # self.f.write(out)
  239. # def println(self, *data):
  240. # if data and not(len(data) == 1 and data[0] ==''):
  241. # self.write(*data)
  242. # self.pending_newlines = max(self.pending_newlines, 1)
  243. # return
  244. # pass
  245. # for doc in (
  246. # "Now is the time",
  247. # r'''func placeholder - with ("""\nstring\n""")''',
  248. # r'''func placeholder - ' and with ("""\nstring\n""")''',
  249. # r"""func placeholder - ' and with ('''\nstring\n''') and \"\"\"\nstring\n\"\"\" """
  250. # ):
  251. # o = PrintFake()
  252. # print_docstring(o, ' ', doc)
  253. # print(o.f.getvalue())