_utils.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139
  1. import ast
  2. import sys
  3. import dis
  4. from typing import cast, Any,Iterator
  5. import types
  6. def assert_(condition, message=""):
  7. # type: (Any, str) -> None
  8. """
  9. Like an assert statement, but unaffected by -O
  10. :param condition: value that is expected to be truthy
  11. :type message: Any
  12. """
  13. if not condition:
  14. raise AssertionError(str(message))
# Select an instruction iterator for the running interpreter. Both branches
# define `Instruction` and `_get_instructions(co)`; the iterated objects expose
# at least `offset`, `argval`, `opname` and `starts_line`, which is all the
# fallback branch provides.
if sys.version_info >= (3, 4):
    # dis.get_instructions is available from 3.4 onwards, so reuse it as-is.
    # noinspection PyUnresolvedReferences
    _get_instructions = dis.get_instructions
    from dis import Instruction as _Instruction

    # Subclass only to declare the `lineno` attribute that get_instructions()
    # assigns. Note: dis.get_instructions yields dis.Instruction objects, not
    # instances of this subclass.
    class Instruction(_Instruction):
        lineno = None # type: int
else:
    from collections import namedtuple

    # Minimal stand-in for dis.Instruction carrying only the fields that the
    # hand-written decoder below fills in.
    class Instruction(namedtuple('Instruction', 'offset argval opname starts_line')):
        lineno = None # type: int

    from dis import HAVE_ARGUMENT, EXTENDED_ARG, hasconst, opname, findlinestarts, hasname

    # Based on dis.disassemble from 2.7
    # Left as similar as possible for easy diff
    #
    # Decodes co.co_code one instruction at a time:
    # - each opcode is one byte;
    # - opcodes >= HAVE_ARGUMENT are followed by a two-byte argument (low byte
    #   first), to which a preceding EXTENDED_ARG adds its value * 65536;
    # - argval is resolved only for constant, name and LOAD_FAST operands and
    #   is None otherwise;
    # - starts_line is the line number from findlinestarts() when the
    #   instruction begins a new line, else None.
    # NOTE(review): ord(code[i]) assumes indexing co_code yields 1-char
    # strings (Python 2). On Python 3.0-3.3 it yields ints and ord() would
    # raise, so this branch effectively supports Python 2 only.
    def _get_instructions(co):
        # type: (types.CodeType) -> Iterator[Instruction]
        code = co.co_code
        linestarts = dict(findlinestarts(co))
        n = len(code)
        i = 0
        extended_arg = 0
        while i < n:
            offset = i
            c = code[i]
            op = ord(c)
            lineno = linestarts.get(i)  # None unless a new source line starts here
            argval = None
            i = i + 1
            if op >= HAVE_ARGUMENT:
                oparg = ord(code[i]) + ord(code[i + 1]) * 256 + extended_arg
                extended_arg = 0
                i = i + 2
                if op == EXTENDED_ARG:
                    # High bits are carried into the next instruction's oparg.
                    extended_arg = oparg * 65536
                if op in hasconst:
                    argval = co.co_consts[oparg]
                elif op in hasname:
                    argval = co.co_names[oparg]
                elif opname[op] == 'LOAD_FAST':
                    argval = co.co_varnames[oparg]
            yield Instruction(offset, argval, opname[op], lineno)
  55. def get_instructions(co):
  56. # type: (types.CodeType) -> Iterator[EnhancedInstruction]
  57. lineno = co.co_firstlineno
  58. for inst in _get_instructions(co):
  59. inst = cast(EnhancedInstruction, inst)
  60. lineno = inst.starts_line or lineno
  61. assert_(lineno)
  62. inst.lineno = lineno
  63. yield inst
  64. # Type class used to expand out the definition of AST to include fields added by this library
  65. # It's not actually used for anything other than type checking though!
  66. class EnhancedAST(ast.AST):
  67. parent = None # type: EnhancedAST
  68. # Type class used to expand out the definition of AST to include fields added by this library
  69. # It's not actually used for anything other than type checking though!
  70. class EnhancedInstruction(Instruction):
  71. _copied = None # type: bool
def mangled_name(node):
    # type: (EnhancedAST) -> str
    """
    Return the name that `node` binds or accesses, with private-name mangling
    applied.

    Parameters:
        node: the node whose name should be mangled. It must carry the
            `parent` attribute that this library attaches to AST nodes
            (``node.parent`` is read directly for private names).

    Returns:
        ``"_" + ClassName.lstrip("_") + name`` when `name` starts with ``__``,
        does not end with ``__``, and an enclosing ClassDef is found (not via
        its bases or decorators) whose name is not made only of underscores.
        Otherwise the plain name.

    Raises:
        TypeError: if `node` is not an Attribute, Name, alias, FunctionDef,
            AsyncFunctionDef, ClassDef, ExceptHandler or (3.12+) TypeVar.
        AssertionError: for an ExceptHandler without a bound name. This is a
            plain ``assert``, so under -O it is skipped and ``name.startswith``
            fails with AttributeError instead.
    """
    function_class_types=(ast.FunctionDef, ast.ClassDef, ast.AsyncFunctionDef)
    # Extract the identifier the node refers to.
    if isinstance(node, ast.Attribute):
        name = node.attr
    elif isinstance(node, ast.Name):
        name = node.id
    elif isinstance(node, (ast.alias)):
        # `import a.b` binds `a`, so only the first dotted component counts.
        name = node.asname or node.name.split(".")[0]
    elif isinstance(node, function_class_types):
        name = node.name
    elif isinstance(node, ast.ExceptHandler):
        assert node.name
        name = node.name
    elif sys.version_info >= (3,12) and isinstance(node,ast.TypeVar):
        name=node.name
    else:
        raise TypeError("no node to mangle")
    # Only __private names are mangled; __dunder__ names are left alone.
    if name.startswith("__") and not name.endswith("__"):
        parent,child=node.parent,node
        # Walk up the tree until reaching a ClassDef that was not entered
        # through one of its base-class expressions.
        while not (isinstance(parent,ast.ClassDef) and child not in parent.bases):
            if not hasattr(parent,"parent"):
                # Reached the root (e.g. the Module) without finding a class.
                break # pragma: no mutate
            parent,child=parent.parent,parent
        else:
            # Runs only when the loop stopped at a qualifying ClassDef (no break).
            # Leading underscores of the class name are dropped; a class whose
            # name is only underscores disables mangling.
            class_name=parent.name.lstrip("_")
            # Names reached through the class's decorator_list are returned
            # unmangled.
            # NOTE(review): class keyword arguments (e.g. metaclass=...) are not
            # excluded the way bases are. Decorator names of a class nested in
            # another class are returned unmangled rather than mangled with
            # the outer class. Confirm against CPython's compiler for those cases.
            if class_name!="" and child not in parent.decorator_list:
                return "_" + class_name + name
    return name