bytecode_analysis.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. """
  2. This module provides utilities for analyzing and optimizing Python bytecode.
  3. Key functionality includes:
  4. - Dead code elimination
  5. - Jump instruction optimization
  6. - Stack size analysis and verification
  7. - Live variable analysis
  8. - Line number propagation and cleanup
  9. - Exception table handling for Python 3.11+
  10. The utilities in this module are used to analyze and transform bytecode
  11. for better performance while maintaining correct semantics.
  12. """
  13. import bisect
  14. import dataclasses
  15. import dis
  16. import itertools
  17. import sys
  18. from typing import Any, TYPE_CHECKING, Union
  19. if TYPE_CHECKING:
  20. # TODO(lucaskabela): consider moving Instruction into this file
  21. # and refactoring in callsite; that way we don't have to guard this import
  22. from .bytecode_transformation import Instruction
  23. TERMINAL_OPCODES = {
  24. dis.opmap["RETURN_VALUE"],
  25. dis.opmap["JUMP_FORWARD"],
  26. dis.opmap["RAISE_VARARGS"],
  27. # TODO(jansel): double check exception handling
  28. }
  29. TERMINAL_OPCODES.add(dis.opmap["RERAISE"])
  30. if sys.version_info >= (3, 11):
  31. TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD"])
  32. TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
  33. else:
  34. TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
  35. if (3, 12) <= sys.version_info < (3, 14):
  36. TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
  37. if sys.version_info >= (3, 13):
  38. TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD_NO_INTERRUPT"])
  39. JUMP_OPCODES = set(dis.hasjrel + dis.hasjabs)
  40. JUMP_OPNAMES = {dis.opname[opcode] for opcode in JUMP_OPCODES}
  41. HASLOCAL = set(dis.haslocal)
  42. HASFREE = set(dis.hasfree)
  43. stack_effect = dis.stack_effect
  44. def get_indexof(insts: list["Instruction"]) -> dict["Instruction", int]:
  45. """
  46. Get a mapping from instruction memory address to index in instruction list.
  47. Additionally checks that each instruction only appears once in the list.
  48. """
  49. # pyrefly: ignore [implicit-any]
  50. indexof = {}
  51. for i, inst in enumerate(insts):
  52. assert inst not in indexof
  53. indexof[inst] = i
  54. return indexof
  55. def remove_dead_code(instructions: list["Instruction"]) -> list["Instruction"]:
  56. """Dead code elimination"""
  57. indexof = get_indexof(instructions)
  58. live_code = set()
  59. def find_live_code(start: int) -> None:
  60. for i in range(start, len(instructions)):
  61. if i in live_code:
  62. return
  63. live_code.add(i)
  64. inst = instructions[i]
  65. if inst.exn_tab_entry:
  66. find_live_code(indexof[inst.exn_tab_entry.target])
  67. if inst.opcode in JUMP_OPCODES:
  68. assert inst.target is not None
  69. find_live_code(indexof[inst.target])
  70. if inst.opcode in TERMINAL_OPCODES:
  71. return
  72. find_live_code(0)
  73. # change exception table entries if start/end instructions are dead
  74. # assumes that exception table entries have been propagated,
  75. # e.g. with bytecode_transformation.propagate_inst_exn_table_entries,
  76. # and that instructions with an exn_tab_entry lies within its start/end.
  77. if sys.version_info >= (3, 11):
  78. live_idx = sorted(live_code)
  79. for i, inst in enumerate(instructions):
  80. if i in live_code and inst.exn_tab_entry:
  81. # find leftmost live instruction >= start
  82. start_idx = bisect.bisect_left(
  83. live_idx, indexof[inst.exn_tab_entry.start]
  84. )
  85. assert start_idx < len(live_idx)
  86. # find rightmost live instruction <= end
  87. end_idx = (
  88. bisect.bisect_right(live_idx, indexof[inst.exn_tab_entry.end]) - 1
  89. )
  90. assert end_idx >= 0
  91. assert live_idx[start_idx] <= i <= live_idx[end_idx]
  92. inst.exn_tab_entry.start = instructions[live_idx[start_idx]]
  93. inst.exn_tab_entry.end = instructions[live_idx[end_idx]]
  94. return [inst for i, inst in enumerate(instructions) if i in live_code]
  95. def remove_pointless_jumps(instructions: list["Instruction"]) -> list["Instruction"]:
  96. """Eliminate jumps to the next instruction"""
  97. pointless_jumps = {
  98. id(a)
  99. for a, b in itertools.pairwise(instructions)
  100. if a.opname == "JUMP_ABSOLUTE" and a.target is b
  101. }
  102. return [inst for inst in instructions if id(inst) not in pointless_jumps]
  103. def propagate_line_nums(instructions: list["Instruction"]) -> None:
  104. """Ensure every instruction has line number set in case some are removed"""
  105. cur_line_no = None
  106. def populate_line_num(inst: "Instruction") -> None:
  107. nonlocal cur_line_no
  108. if inst.starts_line:
  109. cur_line_no = inst.starts_line
  110. inst.starts_line = cur_line_no
  111. for inst in instructions:
  112. populate_line_num(inst)
  113. def remove_extra_line_nums(instructions: list["Instruction"]) -> None:
  114. """Remove extra starts line properties before packing bytecode"""
  115. cur_line_no = None
  116. def remove_line_num(inst: "Instruction") -> None:
  117. nonlocal cur_line_no
  118. if inst.starts_line is None:
  119. return
  120. elif inst.starts_line == cur_line_no:
  121. inst.starts_line = None
  122. else:
  123. cur_line_no = inst.starts_line
  124. for inst in instructions:
  125. remove_line_num(inst)
  126. @dataclasses.dataclass
  127. class ReadsWrites:
  128. reads: set[Any]
  129. writes: set[Any]
  130. visited: set[Any]
  131. def livevars_analysis(
  132. instructions: list["Instruction"], instruction: "Instruction"
  133. ) -> set[Any]:
  134. indexof = get_indexof(instructions)
  135. must = ReadsWrites(set(), set(), set())
  136. may = ReadsWrites(set(), set(), set())
  137. def walk(state: ReadsWrites, start: int) -> None:
  138. if start in state.visited:
  139. return
  140. state.visited.add(start)
  141. for i in range(start, len(instructions)):
  142. inst = instructions[i]
  143. if inst.opcode in HASLOCAL or inst.opcode in HASFREE:
  144. if "LOAD" in inst.opname or "DELETE" in inst.opname:
  145. if inst.argval not in must.writes:
  146. state.reads.add(inst.argval)
  147. elif "STORE" in inst.opname:
  148. state.writes.add(inst.argval)
  149. elif inst.opname == "MAKE_CELL":
  150. pass
  151. else:
  152. raise NotImplementedError(f"unhandled {inst.opname}")
  153. if inst.exn_tab_entry:
  154. walk(may, indexof[inst.exn_tab_entry.target])
  155. if inst.opcode in JUMP_OPCODES:
  156. assert inst.target is not None
  157. walk(may, indexof[inst.target])
  158. state = may
  159. if inst.opcode in TERMINAL_OPCODES:
  160. return
  161. walk(must, indexof[instruction])
  162. return must.reads | may.reads
  163. @dataclasses.dataclass
  164. class FixedPointBox:
  165. value: bool = True
  166. @dataclasses.dataclass
  167. class StackSize:
  168. low: Union[int, float]
  169. high: Union[int, float]
  170. fixed_point: FixedPointBox
  171. def zero(self) -> None:
  172. self.low = 0
  173. self.high = 0
  174. self.fixed_point.value = False
  175. def offset_of(self, other: "StackSize", n: int) -> None:
  176. prior = (self.low, self.high)
  177. self.low = min(self.low, other.low + n)
  178. self.high = max(self.high, other.high + n)
  179. if (self.low, self.high) != prior:
  180. self.fixed_point.value = False
  181. def exn_tab_jump(self, depth: int) -> None:
  182. prior = (self.low, self.high)
  183. self.low = min(self.low, depth)
  184. self.high = max(self.high, depth)
  185. if (self.low, self.high) != prior:
  186. self.fixed_point.value = False
  187. def stacksize_analysis(instructions: list["Instruction"]) -> Union[int, float]:
  188. assert instructions
  189. fixed_point = FixedPointBox()
  190. stack_sizes = {
  191. inst: StackSize(float("inf"), float("-inf"), fixed_point)
  192. for inst in instructions
  193. }
  194. stack_sizes[instructions[0]].zero()
  195. for _ in range(100):
  196. if fixed_point.value:
  197. break
  198. fixed_point.value = True
  199. for inst, next_inst in zip(instructions, instructions[1:] + [None]):
  200. stack_size = stack_sizes[inst]
  201. if inst.opcode not in TERMINAL_OPCODES:
  202. assert next_inst is not None, f"missing next inst: {inst}"
  203. eff = stack_effect(inst.opcode, inst.arg, jump=False)
  204. stack_sizes[next_inst].offset_of(stack_size, eff)
  205. if inst.opcode in JUMP_OPCODES:
  206. assert inst.target is not None, f"missing target: {inst}"
  207. stack_sizes[inst.target].offset_of(
  208. stack_size, stack_effect(inst.opcode, inst.arg, jump=True)
  209. )
  210. if inst.exn_tab_entry:
  211. # see https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
  212. # on why depth is computed this way.
  213. depth = inst.exn_tab_entry.depth + int(inst.exn_tab_entry.lasti) + 1
  214. stack_sizes[inst.exn_tab_entry.target].exn_tab_jump(depth)
  215. low = min(x.low for x in stack_sizes.values())
  216. high = max(x.high for x in stack_sizes.values())
  217. assert fixed_point.value, "failed to reach fixed point"
  218. assert low >= 0
  219. return high