TokenStreamRewriter.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. #
  2. # Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
  3. # Use of this file is governed by the BSD 3-clause license that
  4. # can be found in the LICENSE.txt file in the project root.
  5. #
  6. from io import StringIO
  7. from antlr4.Token import Token
  8. from antlr4.CommonTokenStream import CommonTokenStream
  9. class TokenStreamRewriter(object):
  10. __slots__ = ('tokens', 'programs', 'lastRewriteTokenIndexes')
  11. DEFAULT_PROGRAM_NAME = "default"
  12. PROGRAM_INIT_SIZE = 100
  13. MIN_TOKEN_INDEX = 0
  14. def __init__(self, tokens):
  15. """
  16. :type tokens: antlr4.BufferedTokenStream.BufferedTokenStream
  17. :param tokens:
  18. :return:
  19. """
  20. super(TokenStreamRewriter, self).__init__()
  21. self.tokens = tokens
  22. self.programs = {self.DEFAULT_PROGRAM_NAME: []}
  23. self.lastRewriteTokenIndexes = {}
  24. def getTokenStream(self):
  25. return self.tokens
  26. def rollback(self, instruction_index, program_name):
  27. ins = self.programs.get(program_name, None)
  28. if ins:
  29. self.programs[program_name] = ins[self.MIN_TOKEN_INDEX: instruction_index]
  30. def deleteProgram(self, program_name=DEFAULT_PROGRAM_NAME):
  31. self.rollback(self.MIN_TOKEN_INDEX, program_name)
  32. def insertAfterToken(self, token, text, program_name=DEFAULT_PROGRAM_NAME):
  33. self.insertAfter(token.tokenIndex, text, program_name)
  34. def insertAfter(self, index, text, program_name=DEFAULT_PROGRAM_NAME):
  35. op = self.InsertAfterOp(self.tokens, index + 1, text)
  36. rewrites = self.getProgram(program_name)
  37. op.instructionIndex = len(rewrites)
  38. rewrites.append(op)
  39. def insertBeforeIndex(self, index, text):
  40. self.insertBefore(self.DEFAULT_PROGRAM_NAME, index, text)
  41. def insertBeforeToken(self, token, text, program_name=DEFAULT_PROGRAM_NAME):
  42. self.insertBefore(program_name, token.tokenIndex, text)
  43. def insertBefore(self, program_name, index, text):
  44. op = self.InsertBeforeOp(self.tokens, index, text)
  45. rewrites = self.getProgram(program_name)
  46. op.instructionIndex = len(rewrites)
  47. rewrites.append(op)
  48. def replaceIndex(self, index, text):
  49. self.replace(self.DEFAULT_PROGRAM_NAME, index, index, text)
  50. def replaceRange(self, from_idx, to_idx, text):
  51. self.replace(self.DEFAULT_PROGRAM_NAME, from_idx, to_idx, text)
  52. def replaceSingleToken(self, token, text):
  53. self.replace(self.DEFAULT_PROGRAM_NAME, token.tokenIndex, token.tokenIndex, text)
  54. def replaceRangeTokens(self, from_token, to_token, text, program_name=DEFAULT_PROGRAM_NAME):
  55. self.replace(program_name, from_token.tokenIndex, to_token.tokenIndex, text)
  56. def replace(self, program_name, from_idx, to_idx, text):
  57. if any((from_idx > to_idx, from_idx < 0, to_idx < 0, to_idx >= len(self.tokens.tokens))):
  58. raise ValueError(
  59. 'replace: range invalid: {}..{}(size={})'.format(from_idx, to_idx, len(self.tokens.tokens)))
  60. op = self.ReplaceOp(from_idx, to_idx, self.tokens, text)
  61. rewrites = self.getProgram(program_name)
  62. op.instructionIndex = len(rewrites)
  63. rewrites.append(op)
  def deleteToken(self, token, program_name=DEFAULT_PROGRAM_NAME):
      """
      Queue deletion of a single ``token``.

      :param token: token to delete.
      :param program_name: program to record the deletion in; defaults to the
          default program, consistent with the other program-aware mutators.
      """
      self.delete(program_name, token, token)
  def deleteIndex(self, index, program_name=DEFAULT_PROGRAM_NAME):
      """
      Queue deletion of the single token at ``index``.

      :param index: token index to delete.
      :param program_name: program to record the deletion in; defaults to the
          default program, consistent with the other program-aware mutators.
      """
      self.delete(program_name, index, index)
  def delete(self, program_name, from_idx, to_idx):
      """
      Queue deletion of a token range, expressed as a replace with empty text.

      :param program_name: program to record the deletion in.
      :param from_idx: first index, or a Token (then ``to_idx`` must be a Token too).
      :param to_idx: last index (inclusive), or a Token.
      """
      if isinstance(from_idx, Token):
          # Token arguments are converted to their stream positions.
          from_idx, to_idx = from_idx.tokenIndex, to_idx.tokenIndex
      self.replace(program_name, from_idx, to_idx, "")
  73. def lastRewriteTokenIndex(self, program_name=DEFAULT_PROGRAM_NAME):
  74. return self.lastRewriteTokenIndexes.get(program_name, -1)
  75. def setLastRewriteTokenIndex(self, program_name, i):
  76. self.lastRewriteTokenIndexes[program_name] = i
  77. def getProgram(self, program_name):
  78. return self.programs.setdefault(program_name, [])
  79. def getDefaultText(self):
  80. return self.getText(self.DEFAULT_PROGRAM_NAME, 0, len(self.tokens.tokens) - 1)
  81. def getText(self, program_name, start:int, stop:int):
  82. """
  83. :return: the text in tokens[start, stop](closed interval)
  84. """
  85. rewrites = self.programs.get(program_name)
  86. # ensure start/end are in range
  87. if stop > len(self.tokens.tokens) - 1:
  88. stop = len(self.tokens.tokens) - 1
  89. if start < 0:
  90. start = 0
  91. # if no instructions to execute
  92. if not rewrites: return self.tokens.getText(start, stop)
  93. buf = StringIO()
  94. indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
  95. i = start
  96. while all((i <= stop, i < len(self.tokens.tokens))):
  97. op = indexToOp.pop(i, None)
  98. token = self.tokens.get(i)
  99. if op is None:
  100. if token.type != Token.EOF: buf.write(token.text)
  101. i += 1
  102. else:
  103. i = op.execute(buf)
  104. if stop == len(self.tokens.tokens)-1:
  105. for op in indexToOp.values():
  106. if op.index >= len(self.tokens.tokens)-1: buf.write(op.text)
  107. return buf.getvalue()
  108. def _reduceToSingleOperationPerIndex(self, rewrites):
  109. # Walk replaces
  110. for i, rop in enumerate(rewrites):
  111. if any((rop is None, not isinstance(rop, TokenStreamRewriter.ReplaceOp))):
  112. continue
  113. # Wipe prior inserts within range
  114. inserts = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.InsertBeforeOp)]
  115. for iop in inserts:
  116. if iop.index == rop.index:
  117. rewrites[iop.instructionIndex] = None
  118. rop.text = '{}{}'.format(iop.text, rop.text)
  119. elif all((iop.index > rop.index, iop.index <= rop.last_index)):
  120. rewrites[iop.instructionIndex] = None
  121. # Drop any prior replaces contained within
  122. prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)]
  123. for prevRop in prevReplaces:
  124. if all((prevRop.index >= rop.index, prevRop.last_index <= rop.last_index)):
  125. rewrites[prevRop.instructionIndex] = None
  126. continue
  127. isDisjoint = any((prevRop.last_index<rop.index, prevRop.index>rop.last_index))
  128. if all((prevRop.text is None, rop.text is None, not isDisjoint)):
  129. rewrites[prevRop.instructionIndex] = None
  130. rop.index = min(prevRop.index, rop.index)
  131. rop.last_index = min(prevRop.last_index, rop.last_index)
  132. print('New rop {}'.format(rop))
  133. elif (not(isDisjoint)):
  134. raise ValueError("replace op boundaries of {} overlap with previous {}".format(rop, prevRop))
  135. # Walk inserts
  136. for i, iop in enumerate(rewrites):
  137. if any((iop is None, not isinstance(iop, TokenStreamRewriter.InsertBeforeOp))):
  138. continue
  139. prevInserts = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.InsertBeforeOp)]
  140. for prev_index, prevIop in enumerate(prevInserts):
  141. if prevIop.index == iop.index and type(prevIop) is TokenStreamRewriter.InsertBeforeOp:
  142. iop.text += prevIop.text
  143. rewrites[prev_index] = None
  144. elif prevIop.index == iop.index and type(prevIop) is TokenStreamRewriter.InsertAfterOp:
  145. iop.text = prevIop.text + iop.text
  146. rewrites[prev_index] = None
  147. # look for replaces where iop.index is in range; error
  148. prevReplaces = [op for op in rewrites[:i] if isinstance(op, TokenStreamRewriter.ReplaceOp)]
  149. for rop in prevReplaces:
  150. if iop.index == rop.index:
  151. rop.text = iop.text + rop.text
  152. rewrites[i] = None
  153. continue
  154. if all((iop.index >= rop.index, iop.index <= rop.last_index)):
  155. raise ValueError("insert op {} within boundaries of previous {}".format(iop, rop))
  156. reduced = {}
  157. for i, op in enumerate(rewrites):
  158. if op is None: continue
  159. if reduced.get(op.index): raise ValueError('should be only one op per index')
  160. reduced[op.index] = op
  161. return reduced
  class RewriteOperation(object):
      """
      Base class for one queued edit anchored at token ``index``.

      The base ``execute`` writes nothing; subclasses override it.
      """
      __slots__ = ('tokens', 'index', 'text', 'instructionIndex')

      def __init__(self, tokens, index, text=""):
          """
          :type tokens: CommonTokenStream
          :param tokens: token stream the operation refers to.
          :param index: token index the operation is anchored at.
          :param text: text the operation emits.
          """
          self.tokens, self.index, self.text = tokens, index, text
          # Position inside the owning program; the rewriter overwrites it on append.
          self.instructionIndex = 0

      def execute(self, buf):
          """
          Render this operation into ``buf``.

          :type buf: StringIO
          :return: the token index at which rendering resumes (here ``index`` itself).
          """
          return self.index

      def __str__(self):
          kind = self.__class__.__name__
          anchor = self.tokens.get(self.index)
          return '<{}@{}:"{}">'.format(kind, anchor, self.text)
  class InsertBeforeOp(RewriteOperation):
      """Emit ``text`` followed by the token at ``index``."""

      def __init__(self, tokens, index, text=""):
          super().__init__(tokens, index, text)

      def execute(self, buf):
          """
          Write the inserted text, then the anchor token's text unless it is EOF.

          :return: the index of the next token to render.
          """
          buf.write(self.text)
          anchor = self.tokens.get(self.index)
          if anchor.type != Token.EOF:
              buf.write(anchor.text)
          return self.index + 1
  class InsertAfterOp(InsertBeforeOp):
      """
      Insert-after marker type; renders exactly like InsertBeforeOp.

      insertAfter() anchors it at the index of the token following the target,
      so its text lands after the target. The distinct type lets the reducer
      keep an earlier insert-after's text ahead of later inserts at that index.
      """
  195. class ReplaceOp(RewriteOperation):
  196. __slots__ = 'last_index'
  197. def __init__(self, from_idx, to_idx, tokens, text):
  198. super(TokenStreamRewriter.ReplaceOp, self).__init__(tokens, from_idx, text)
  199. self.last_index = to_idx
  def execute(self, buf):
      """
      Write the replacement text (nothing when it is empty or None) and skip the range.

      :return: the index just past the replaced range.
      """
      replacement = self.text
      if replacement:
          buf.write(replacement)
      return self.last_index + 1
  204. def __str__(self):
  205. if self.text:
  206. return '<ReplaceOp@{}..{}:"{}">'.format(self.tokens.get(self.index), self.tokens.get(self.last_index),
  207. self.text)