| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
#
# Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
# Use of this file is governed by the BSD 3-clause license that
# can be found in the LICENSE.txt file in the project root.
#
from io import StringIO

from antlr4.Token import Token
from antlr4.CommonTokenStream import CommonTokenStream
class TokenStreamRewriter(object):
    """Queue edit instructions against a token stream and render the result.

    Instructions (insert-before, insert-after, replace/delete) are recorded
    per named "program" and never modify the underlying token stream.  The
    edited text is produced on demand by :meth:`getText`.
    """
    __slots__ = ('tokens', 'programs', 'lastRewriteTokenIndexes')

    DEFAULT_PROGRAM_NAME = "default"
    PROGRAM_INIT_SIZE = 100
    MIN_TOKEN_INDEX = 0

    def __init__(self, tokens):
        """
        :type  tokens: antlr4.BufferedTokenStream.BufferedTokenStream
        :param tokens: stream whose ``tokens`` list, ``get(i)`` and
                       ``getText(start, stop)`` are used by this class
        """
        super(TokenStreamRewriter, self).__init__()
        self.tokens = tokens
        # program name -> list of RewriteOperation (entries may become None
        # once they have been folded into another operation)
        self.programs = {self.DEFAULT_PROGRAM_NAME: []}
        self.lastRewriteTokenIndexes = {}

    def getTokenStream(self):
        """Return the token stream passed to the constructor."""
        return self.tokens

    def rollback(self, instruction_index, program_name):
        """Drop every instruction at or after ``instruction_index``.

        Does nothing when the program does not exist or is already empty.
        """
        ins = self.programs.get(program_name, None)
        if ins:
            self.programs[program_name] = ins[self.MIN_TOKEN_INDEX: instruction_index]

    def deleteProgram(self, program_name=DEFAULT_PROGRAM_NAME):
        """Discard all instructions of ``program_name``."""
        self.rollback(self.MIN_TOKEN_INDEX, program_name)

    def insertAfterToken(self, token, text, program_name=DEFAULT_PROGRAM_NAME):
        """Queue ``text`` for insertion after ``token`` (uses ``token.tokenIndex``)."""
        self.insertAfter(token.tokenIndex, text, program_name)

    def insertAfter(self, index, text, program_name=DEFAULT_PROGRAM_NAME):
        """Queue ``text`` for insertion after the token at ``index``.

        The operation is stored as an ``InsertAfterOp`` anchored at
        ``index + 1``, i.e. it is rendered in front of the following token.
        """
        op = self.InsertAfterOp(self.tokens, index + 1, text)
        rewrites = self.getProgram(program_name)
        op.instructionIndex = len(rewrites)
        rewrites.append(op)

    def insertBeforeIndex(self, index, text):
        """Queue ``text`` before token ``index`` in the default program."""
        self.insertBefore(self.DEFAULT_PROGRAM_NAME, index, text)

    def insertBeforeToken(self, token, text, program_name=DEFAULT_PROGRAM_NAME):
        """Queue ``text`` for insertion before ``token`` (uses ``token.tokenIndex``)."""
        self.insertBefore(program_name, token.tokenIndex, text)

    def insertBefore(self, program_name, index, text):
        """Queue ``text`` for insertion before the token at ``index``."""
        op = self.InsertBeforeOp(self.tokens, index, text)
        rewrites = self.getProgram(program_name)
        op.instructionIndex = len(rewrites)
        rewrites.append(op)

    def replaceIndex(self, index, text):
        """Replace the single token at ``index`` in the default program."""
        self.replace(self.DEFAULT_PROGRAM_NAME, index, index, text)

    def replaceRange(self, from_idx, to_idx, text):
        """Replace tokens ``from_idx..to_idx`` (inclusive) in the default program."""
        self.replace(self.DEFAULT_PROGRAM_NAME, from_idx, to_idx, text)

    def replaceSingleToken(self, token, text):
        """Replace ``token`` with ``text`` in the default program."""
        self.replace(self.DEFAULT_PROGRAM_NAME, token.tokenIndex, token.tokenIndex, text)

    def replaceRangeTokens(self, from_token, to_token, text, program_name=DEFAULT_PROGRAM_NAME):
        """Replace everything from ``from_token`` through ``to_token`` with ``text``."""
        self.replace(program_name, from_token.tokenIndex, to_token.tokenIndex, text)

    def replace(self, program_name, from_idx, to_idx, text):
        """Queue a replacement of tokens ``from_idx..to_idx`` (inclusive).

        :raises ValueError: if the range is reversed, negative, or ends past
                            the last token of the stream
        """
        size = len(self.tokens.tokens)
        if from_idx > to_idx or from_idx < 0 or to_idx < 0 or to_idx >= size:
            raise ValueError(
                'replace: range invalid: {}..{}(size={})'.format(from_idx, to_idx, size))
        op = self.ReplaceOp(from_idx, to_idx, self.tokens, text)
        rewrites = self.getProgram(program_name)
        op.instructionIndex = len(rewrites)
        rewrites.append(op)

    def deleteToken(self, token):
        """Delete ``token`` in the default program."""
        self.delete(self.DEFAULT_PROGRAM_NAME, token, token)

    def deleteIndex(self, index):
        """Delete the token at ``index`` in the default program."""
        self.delete(self.DEFAULT_PROGRAM_NAME, index, index)

    def delete(self, program_name, from_idx, to_idx):
        """Delete a range given either as two tokens or as two indexes.

        Implemented as a replacement with the empty string.  Only
        ``from_idx`` is type-checked; when it is a ``Token`` both arguments
        are read through ``tokenIndex``.
        """
        if isinstance(from_idx, Token):
            self.replace(program_name, from_idx.tokenIndex, to_idx.tokenIndex, "")
        else:
            self.replace(program_name, from_idx, to_idx, "")

    def lastRewriteTokenIndex(self, program_name=DEFAULT_PROGRAM_NAME):
        """Return the index recorded for ``program_name``, or -1 if none."""
        return self.lastRewriteTokenIndexes.get(program_name, -1)

    def setLastRewriteTokenIndex(self, program_name, i):
        """Record ``i`` as the last rewrite token index of ``program_name``."""
        self.lastRewriteTokenIndexes[program_name] = i

    def getProgram(self, program_name):
        """Return the instruction list of ``program_name``, creating it if absent."""
        return self.programs.setdefault(program_name, [])

    def getDefaultText(self):
        """Render the whole stream using the default program."""
        return self.getText(self.DEFAULT_PROGRAM_NAME, 0, len(self.tokens.tokens) - 1)

    def getText(self, program_name, start:int, stop:int):
        """
        Render tokens[start, stop] (closed interval) with the instructions
        of ``program_name`` applied.

        ``start``/``stop`` are clamped to the stream.  When the program has
        no instructions the stream's own ``getText`` is returned.

        NOTE: the program's instruction list is reduced in place (see
        ``_reduceToSingleOperationPerIndex``).

        :return: the text in tokens[start, stop](closed interval)
        """
        rewrites = self.programs.get(program_name)

        # ensure start/end are in range
        if stop > len(self.tokens.tokens) - 1:
            stop = len(self.tokens.tokens) - 1
        if start < 0:
            start = 0

        # if no instructions to execute
        if not rewrites:
            return self.tokens.getText(start, stop)

        buf = StringIO()
        indexToOp = self._reduceToSingleOperationPerIndex(rewrites)
        i = start
        while i <= stop and i < len(self.tokens.tokens):
            op = indexToOp.pop(i, None)
            token = self.tokens.get(i)
            if op is None:
                # untouched token: copy it through (EOF contributes no text)
                if token.type != Token.EOF:
                    buf.write(token.text)
                i += 1
            else:
                # the op writes its output and tells us where to resume
                i = op.execute(buf)

        # Ops anchored at/after the last token that the loop never reached
        # (e.g. an insert-after on the final token) are appended at the end.
        if stop == len(self.tokens.tokens) - 1:
            for op in indexToOp.values():
                if op.index >= len(self.tokens.tokens) - 1:
                    buf.write(op.text)

        return buf.getvalue()

    @staticmethod
    def _catOpText(a, b):
        """Concatenate two op texts, treating ``None`` as the empty string."""
        left = "" if a is None else str(a)
        right = "" if b is None else str(b)
        return left + right

    def _reduceToSingleOperationPerIndex(self, rewrites):
        """Fold ``rewrites`` so that at most one operation remains per index.

        ``rewrites`` is modified in place: operations that are absorbed by
        another one are set to ``None`` and the absorbing operation's
        ``text`` / range is updated.

        :param rewrites: instruction list of one program
        :return: dict mapping token index -> remaining operation
        :raises ValueError: on partially overlapping replaces, on an insert
                            that falls inside an earlier replace, or if two
                            operations still share an index after reduction
        """
        ReplaceOp = TokenStreamRewriter.ReplaceOp
        InsertBeforeOp = TokenStreamRewriter.InsertBeforeOp
        InsertAfterOp = TokenStreamRewriter.InsertAfterOp
        cat = TokenStreamRewriter._catOpText

        # Walk replaces
        for i, rop in enumerate(rewrites):
            if not isinstance(rop, ReplaceOp):
                continue

            # Wipe prior inserts within range
            inserts = [op for op in rewrites[:i] if isinstance(op, InsertBeforeOp)]
            for iop in inserts:
                if iop.index == rop.index:
                    # insert at the start of the replaced range: keep its
                    # text by prefixing it to the replacement
                    rewrites[iop.instructionIndex] = None
                    rop.text = cat(iop.text, rop.text)
                elif rop.index < iop.index <= rop.last_index:
                    # insert strictly inside the replaced range: dropped
                    rewrites[iop.instructionIndex] = None

            # Drop any prior replaces contained within
            prevReplaces = [op for op in rewrites[:i] if isinstance(op, ReplaceOp)]
            for prevRop in prevReplaces:
                if prevRop.index >= rop.index and prevRop.last_index <= rop.last_index:
                    rewrites[prevRop.instructionIndex] = None
                    continue
                isDisjoint = prevRop.last_index < rop.index or prevRop.index > rop.last_index
                if prevRop.text is None and rop.text is None and not isDisjoint:
                    # two overlapping deletes: merge into one covering the
                    # union of both ranges (hence min start / max end)
                    rewrites[prevRop.instructionIndex] = None
                    rop.index = min(prevRop.index, rop.index)
                    rop.last_index = max(prevRop.last_index, rop.last_index)
                elif not isDisjoint:
                    raise ValueError(
                        "replace op boundaries of {} overlap with previous {}".format(rop, prevRop))

        # Walk inserts
        for i, iop in enumerate(rewrites):
            if not isinstance(iop, InsertBeforeOp):
                continue

            prevInserts = [op for op in rewrites[:i] if isinstance(op, InsertBeforeOp)]
            for prevIop in prevInserts:
                if prevIop.index != iop.index:
                    continue
                # Exact type matters here: InsertAfterOp subclasses
                # InsertBeforeOp but is combined in the opposite order.
                if type(prevIop) is InsertBeforeOp:
                    iop.text = cat(iop.text, prevIop.text)
                    # clear the op's real slot in ``rewrites`` (its position
                    # in the filtered ``prevInserts`` list is unrelated)
                    rewrites[prevIop.instructionIndex] = None
                elif type(prevIop) is InsertAfterOp:
                    iop.text = cat(prevIop.text, iop.text)
                    rewrites[prevIop.instructionIndex] = None

            # look for replaces where iop.index is in range; error
            prevReplaces = [op for op in rewrites[:i] if isinstance(op, ReplaceOp)]
            for rop in prevReplaces:
                if iop.index == rop.index:
                    rop.text = cat(iop.text, rop.text)
                    rewrites[i] = None
                    continue
                if rop.index <= iop.index <= rop.last_index:
                    raise ValueError(
                        "insert op {} within boundaries of previous {}".format(iop, rop))

        reduced = {}
        for op in rewrites:
            if op is None:
                continue
            if op.index in reduced:
                raise ValueError('should be only one op per index')
            reduced[op.index] = op
        return reduced

    class RewriteOperation(object):
        """Base class of all queued edit instructions."""
        __slots__ = ('tokens', 'index', 'text', 'instructionIndex')

        def __init__(self, tokens, index, text=""):
            """
            :type  tokens: CommonTokenStream
            :param tokens: stream the operation refers to
            :param index:  token index the operation is anchored at
            :param text:   text carried by the operation
            """
            self.tokens = tokens
            self.index = index
            self.text = text
            # position of this op inside its program list; assigned by the
            # rewriter when the op is queued
            self.instructionIndex = 0

        def execute(self, buf):
            """
            Write this operation's output to ``buf``.

            :type  buf: StringIO.StringIO
            :param buf: output buffer
            :return: index of the next token to process (base: ``self.index``)
            """
            return self.index

        def __str__(self):
            return '<{}@{}:"{}">'.format(self.__class__.__name__, self.tokens.get(self.index), self.text)

    class InsertBeforeOp(RewriteOperation):
        """Emit ``text`` followed by the anchor token's own text."""

        def __init__(self, tokens, index, text=""):
            super(TokenStreamRewriter.InsertBeforeOp, self).__init__(tokens, index, text)

        def execute(self, buf):
            """Write the inserted text, then the anchor token unless it is EOF."""
            buf.write(self.text)
            anchor = self.tokens.get(self.index)
            if anchor.type != Token.EOF:
                buf.write(anchor.text)
            return self.index + 1

    class InsertAfterOp(InsertBeforeOp):
        """Marker subclass: an insert-after, stored anchored at ``index + 1``."""
        pass

    class ReplaceOp(RewriteOperation):
        """Replace tokens ``index..last_index`` (inclusive) with ``text``."""
        __slots__ = ('last_index',)

        def __init__(self, from_idx, to_idx, tokens, text):
            super(TokenStreamRewriter.ReplaceOp, self).__init__(tokens, from_idx, text)
            self.last_index = to_idx

        def execute(self, buf):
            """Write the replacement text (if any) and skip past the range."""
            if self.text:
                buf.write(self.text)
            return self.last_index + 1

        def __str__(self):
            first = self.tokens.get(self.index)
            last = self.tokens.get(self.last_index)
            if self.text:
                return '<ReplaceOp@{}..{}:"{}">'.format(first, last, self.text)
            # Empty/None text is a deletion; always return a string so the
            # op can be formatted into error messages.
            return '<DeleteOp@{}..{}>'.format(first, last)
|