ParseTreePatternMatcher.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. #
  2. # Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
  3. # Use of this file is governed by the BSD 3-clause license that
  4. # can be found in the LICENSE.txt file in the project root.
  5. #
  6. #
  7. # A tree pattern matching mechanism for ANTLR {@link ParseTree}s.
  8. #
  9. # <p>Patterns are strings of source input text with special tags representing
  10. # token or rule references such as:</p>
  11. #
  12. # <p>{@code <ID> = <expr>;}</p>
  13. #
  14. # <p>Given a pattern start rule such as {@code statement}, this object constructs
  15. # a {@link ParseTree} with placeholders for the {@code ID} and {@code expr}
  16. # subtree. Then the {@link #match} routines can compare an actual
  17. # {@link ParseTree} from a parse with this pattern. Tag {@code <ID>} matches
  18. # any {@code ID} token and tag {@code <expr>} references the result of the
  19. # {@code expr} rule (generally an instance of {@code ExprContext}).</p>
  20. #
  21. # <p>Pattern {@code x = 0;} is a similar pattern that matches the same pattern
  22. # except that it requires the identifier to be {@code x} and the expression to
  23. # be {@code 0}.</p>
  24. #
  25. # <p>The {@link #matches} routines return {@code true} or {@code false} based
  26. # upon a match for the tree rooted at the parameter sent in. The
  27. # {@link #match} routines return a {@link ParseTreeMatch} object that
  28. # contains the parse tree, the parse tree pattern, and a map from tag name to
  29. # matched nodes (more below). A subtree that fails to match returns with
  30. # {@link ParseTreeMatch#mismatchedNode} set to the first tree node that did not
  31. # match.</p>
  32. #
  33. # <p>For efficiency, you can compile a tree pattern in string form to a
  34. # {@link ParseTreePattern} object.</p>
  35. #
  36. # <p>See {@code TestParseTreeMatcher} for lots of examples.
  37. # {@link ParseTreePattern} has two static helper methods:
  38. # {@link ParseTreePattern#findAll} and {@link ParseTreePattern#match} that
  39. # are easy to use but not super efficient because they create new
  40. # {@link ParseTreePatternMatcher} objects each time and have to compile the
  41. # pattern in string form before using it.</p>
  42. #
  43. # <p>The lexer and parser that you pass into the {@link ParseTreePatternMatcher}
  44. # constructor are used to parse the pattern in string form. The lexer converts
  45. # the {@code <ID> = <expr>;} into a sequence of four tokens (assuming lexer
  46. # throws out whitespace or puts it on a hidden channel). Be aware that the
  47. # input stream is reset for the lexer (but not the parser; a
  48. # {@link ParserInterpreter} is created to parse the input.). Any user-defined
  49. # fields you have put into the lexer might get changed when this mechanism asks
  50. # it to scan the pattern string.</p>
  51. #
  52. # <p>Normally a parser does not accept token {@code <expr>} as a valid
  53. # {@code expr} but, from the parser passed in, we create a special version of
  54. # the underlying grammar representation (an {@link ATN}) that allows imaginary
  55. # tokens representing rules ({@code <expr>}) to match entire rules. We call
  56. # these <em>bypass alternatives</em>.</p>
  57. #
  58. # <p>Delimiters are {@code <} and {@code >}, with {@code \} as the escape string
  59. # by default, but you can set them to whatever you want using
  60. # {@link #setDelimiters}. You must escape both start and stop strings
  61. # {@code \<} and {@code \>}.</p>
  62. #
  63. from antlr4.CommonTokenStream import CommonTokenStream
  64. from antlr4.InputStream import InputStream
  65. from antlr4.ParserRuleContext import ParserRuleContext
  66. from antlr4.Lexer import Lexer
  67. from antlr4.ListTokenSource import ListTokenSource
  68. from antlr4.Token import Token
  69. from antlr4.error.ErrorStrategy import BailErrorStrategy
  70. from antlr4.error.Errors import RecognitionException, ParseCancellationException
  71. from antlr4.tree.Chunk import TagChunk, TextChunk
  72. from antlr4.tree.RuleTagToken import RuleTagToken
  73. from antlr4.tree.TokenTagToken import TokenTagToken
  74. from antlr4.tree.Tree import ParseTree, TerminalNode, RuleNode
  75. # need forward declaration
  76. Parser = None
  77. ParseTreePattern = None
  78. class CannotInvokeStartRule(Exception):
  79. def __init__(self, e:Exception):
  80. super().__init__(e)
  81. class StartRuleDoesNotConsumeFullPattern(Exception):
  82. pass
  83. class ParseTreePatternMatcher(object):
  84. __slots__ = ('lexer', 'parser', 'start', 'stop', 'escape')
  85. # Constructs a {@link ParseTreePatternMatcher} or from a {@link Lexer} and
  86. # {@link Parser} object. The lexer input stream is altered for tokenizing
  87. # the tree patterns. The parser is used as a convenient mechanism to get
  88. # the grammar name, plus token, rule names.
  89. def __init__(self, lexer:Lexer, parser:Parser):
  90. self.lexer = lexer
  91. self.parser = parser
  92. self.start = "<"
  93. self.stop = ">"
  94. self.escape = "\\" # e.g., \< and \> must escape BOTH!
  95. # Set the delimiters used for marking rule and token tags within concrete
  96. # syntax used by the tree pattern parser.
  97. #
  98. # @param start The start delimiter.
  99. # @param stop The stop delimiter.
  100. # @param escapeLeft The escape sequence to use for escaping a start or stop delimiter.
  101. #
  102. # @exception IllegalArgumentException if {@code start} is {@code null} or empty.
  103. # @exception IllegalArgumentException if {@code stop} is {@code null} or empty.
  104. #
  105. def setDelimiters(self, start:str, stop:str, escapeLeft:str):
  106. if start is None or len(start)==0:
  107. raise Exception("start cannot be null or empty")
  108. if stop is None or len(stop)==0:
  109. raise Exception("stop cannot be null or empty")
  110. self.start = start
  111. self.stop = stop
  112. self.escape = escapeLeft
  113. # Does {@code pattern} matched as rule {@code patternRuleIndex} match {@code tree}?#
  114. def matchesRuleIndex(self, tree:ParseTree, pattern:str, patternRuleIndex:int):
  115. p = self.compileTreePattern(pattern, patternRuleIndex)
  116. return self.matches(tree, p)
  117. # Does {@code pattern} matched as rule patternRuleIndex match tree? Pass in a
  118. # compiled pattern instead of a string representation of a tree pattern.
  119. #
  120. def matchesPattern(self, tree:ParseTree, pattern:ParseTreePattern):
  121. mismatchedNode = self.matchImpl(tree, pattern.patternTree, dict())
  122. return mismatchedNode is None
  123. #
  124. # Compare {@code pattern} matched as rule {@code patternRuleIndex} against
  125. # {@code tree} and return a {@link ParseTreeMatch} object that contains the
  126. # matched elements, or the node at which the match failed.
  127. #
  128. def matchRuleIndex(self, tree:ParseTree, pattern:str, patternRuleIndex:int):
  129. p = self.compileTreePattern(pattern, patternRuleIndex)
  130. return self.matchPattern(tree, p)
  131. #
  132. # Compare {@code pattern} matched against {@code tree} and return a
  133. # {@link ParseTreeMatch} object that contains the matched elements, or the
  134. # node at which the match failed. Pass in a compiled pattern instead of a
  135. # string representation of a tree pattern.
  136. #
  137. def matchPattern(self, tree:ParseTree, pattern:ParseTreePattern):
  138. labels = dict()
  139. mismatchedNode = self.matchImpl(tree, pattern.patternTree, labels)
  140. from antlr4.tree.ParseTreeMatch import ParseTreeMatch
  141. return ParseTreeMatch(tree, pattern, labels, mismatchedNode)
  142. #
  143. # For repeated use of a tree pattern, compile it to a
  144. # {@link ParseTreePattern} using this method.
  145. #
  146. def compileTreePattern(self, pattern:str, patternRuleIndex:int):
  147. tokenList = self.tokenize(pattern)
  148. tokenSrc = ListTokenSource(tokenList)
  149. tokens = CommonTokenStream(tokenSrc)
  150. from antlr4.ParserInterpreter import ParserInterpreter
  151. parserInterp = ParserInterpreter(self.parser.grammarFileName, self.parser.tokenNames,
  152. self.parser.ruleNames, self.parser.getATNWithBypassAlts(),tokens)
  153. tree = None
  154. try:
  155. parserInterp.setErrorHandler(BailErrorStrategy())
  156. tree = parserInterp.parse(patternRuleIndex)
  157. except ParseCancellationException as e:
  158. raise e.cause
  159. except RecognitionException as e:
  160. raise e
  161. except Exception as e:
  162. raise CannotInvokeStartRule(e)
  163. # Make sure tree pattern compilation checks for a complete parse
  164. if tokens.LA(1)!=Token.EOF:
  165. raise StartRuleDoesNotConsumeFullPattern()
  166. from antlr4.tree.ParseTreePattern import ParseTreePattern
  167. return ParseTreePattern(self, pattern, patternRuleIndex, tree)
  168. #
  169. # Recursively walk {@code tree} against {@code patternTree}, filling
  170. # {@code match.}{@link ParseTreeMatch#labels labels}.
  171. #
  172. # @return the first node encountered in {@code tree} which does not match
  173. # a corresponding node in {@code patternTree}, or {@code null} if the match
  174. # was successful. The specific node returned depends on the matching
  175. # algorithm used by the implementation, and may be overridden.
  176. #
  177. def matchImpl(self, tree:ParseTree, patternTree:ParseTree, labels:dict):
  178. if tree is None:
  179. raise Exception("tree cannot be null")
  180. if patternTree is None:
  181. raise Exception("patternTree cannot be null")
  182. # x and <ID>, x and y, or x and x; or could be mismatched types
  183. if isinstance(tree, TerminalNode) and isinstance(patternTree, TerminalNode ):
  184. mismatchedNode = None
  185. # both are tokens and they have same type
  186. if tree.symbol.type == patternTree.symbol.type:
  187. if isinstance( patternTree.symbol, TokenTagToken ): # x and <ID>
  188. tokenTagToken = patternTree.symbol
  189. # track label->list-of-nodes for both token name and label (if any)
  190. self.map(labels, tokenTagToken.tokenName, tree)
  191. if tokenTagToken.label is not None:
  192. self.map(labels, tokenTagToken.label, tree)
  193. elif tree.getText()==patternTree.getText():
  194. # x and x
  195. pass
  196. else:
  197. # x and y
  198. if mismatchedNode is None:
  199. mismatchedNode = tree
  200. else:
  201. if mismatchedNode is None:
  202. mismatchedNode = tree
  203. return mismatchedNode
  204. if isinstance(tree, ParserRuleContext) and isinstance(patternTree, ParserRuleContext):
  205. mismatchedNode = None
  206. # (expr ...) and <expr>
  207. ruleTagToken = self.getRuleTagToken(patternTree)
  208. if ruleTagToken is not None:
  209. m = None
  210. if tree.ruleContext.ruleIndex == patternTree.ruleContext.ruleIndex:
  211. # track label->list-of-nodes for both rule name and label (if any)
  212. self.map(labels, ruleTagToken.ruleName, tree)
  213. if ruleTagToken.label is not None:
  214. self.map(labels, ruleTagToken.label, tree)
  215. else:
  216. if mismatchedNode is None:
  217. mismatchedNode = tree
  218. return mismatchedNode
  219. # (expr ...) and (expr ...)
  220. if tree.getChildCount()!=patternTree.getChildCount():
  221. if mismatchedNode is None:
  222. mismatchedNode = tree
  223. return mismatchedNode
  224. n = tree.getChildCount()
  225. for i in range(0, n):
  226. childMatch = self.matchImpl(tree.getChild(i), patternTree.getChild(i), labels)
  227. if childMatch is not None:
  228. return childMatch
  229. return mismatchedNode
  230. # if nodes aren't both tokens or both rule nodes, can't match
  231. return tree
  232. def map(self, labels, label, tree):
  233. v = labels.get(label, None)
  234. if v is None:
  235. v = list()
  236. labels[label] = v
  237. v.append(tree)
  238. # Is {@code t} {@code (expr <expr>)} subtree?#
  239. def getRuleTagToken(self, tree:ParseTree):
  240. if isinstance( tree, RuleNode ):
  241. if tree.getChildCount()==1 and isinstance(tree.getChild(0), TerminalNode ):
  242. c = tree.getChild(0)
  243. if isinstance( c.symbol, RuleTagToken ):
  244. return c.symbol
  245. return None
  246. def tokenize(self, pattern:str):
  247. # split pattern into chunks: sea (raw input) and islands (<ID>, <expr>)
  248. chunks = self.split(pattern)
  249. # create token stream from text and tags
  250. tokens = list()
  251. for chunk in chunks:
  252. if isinstance( chunk, TagChunk ):
  253. # add special rule token or conjure up new token from name
  254. if chunk.tag[0].isupper():
  255. ttype = self.parser.getTokenType(chunk.tag)
  256. if ttype==Token.INVALID_TYPE:
  257. raise Exception("Unknown token " + str(chunk.tag) + " in pattern: " + pattern)
  258. tokens.append(TokenTagToken(chunk.tag, ttype, chunk.label))
  259. elif chunk.tag[0].islower():
  260. ruleIndex = self.parser.getRuleIndex(chunk.tag)
  261. if ruleIndex==-1:
  262. raise Exception("Unknown rule " + str(chunk.tag) + " in pattern: " + pattern)
  263. ruleImaginaryTokenType = self.parser.getATNWithBypassAlts().ruleToTokenType[ruleIndex]
  264. tokens.append(RuleTagToken(chunk.tag, ruleImaginaryTokenType, chunk.label))
  265. else:
  266. raise Exception("invalid tag: " + str(chunk.tag) + " in pattern: " + pattern)
  267. else:
  268. self.lexer.setInputStream(InputStream(chunk.text))
  269. t = self.lexer.nextToken()
  270. while t.type!=Token.EOF:
  271. tokens.append(t)
  272. t = self.lexer.nextToken()
  273. return tokens
  274. # Split {@code <ID> = <e:expr> ;} into 4 chunks for tokenizing by {@link #tokenize}.#
  275. def split(self, pattern:str):
  276. p = 0
  277. n = len(pattern)
  278. chunks = list()
  279. # find all start and stop indexes first, then collect
  280. starts = list()
  281. stops = list()
  282. while p < n :
  283. if p == pattern.find(self.escape + self.start, p):
  284. p += len(self.escape) + len(self.start)
  285. elif p == pattern.find(self.escape + self.stop, p):
  286. p += len(self.escape) + len(self.stop)
  287. elif p == pattern.find(self.start, p):
  288. starts.append(p)
  289. p += len(self.start)
  290. elif p == pattern.find(self.stop, p):
  291. stops.append(p)
  292. p += len(self.stop)
  293. else:
  294. p += 1
  295. nt = len(starts)
  296. if nt > len(stops):
  297. raise Exception("unterminated tag in pattern: " + pattern)
  298. if nt < len(stops):
  299. raise Exception("missing start tag in pattern: " + pattern)
  300. for i in range(0, nt):
  301. if starts[i] >= stops[i]:
  302. raise Exception("tag delimiters out of order in pattern: " + pattern)
  303. # collect into chunks now
  304. if nt==0:
  305. chunks.append(TextChunk(pattern))
  306. if nt>0 and starts[0]>0: # copy text up to first tag into chunks
  307. text = pattern[0:starts[0]]
  308. chunks.add(TextChunk(text))
  309. for i in range(0, nt):
  310. # copy inside of <tag>
  311. tag = pattern[starts[i] + len(self.start) : stops[i]]
  312. ruleOrToken = tag
  313. label = None
  314. colon = tag.find(':')
  315. if colon >= 0:
  316. label = tag[0:colon]
  317. ruleOrToken = tag[colon+1 : len(tag)]
  318. chunks.append(TagChunk(label, ruleOrToken))
  319. if i+1 < len(starts):
  320. # copy from end of <tag> to start of next
  321. text = pattern[stops[i] + len(self.stop) : starts[i + 1]]
  322. chunks.append(TextChunk(text))
  323. if nt > 0 :
  324. afterLastTag = stops[nt - 1] + len(self.stop)
  325. if afterLastTag < n : # copy text from end of last tag to end
  326. text = pattern[afterLastTag : n]
  327. chunks.append(TextChunk(text))
  328. # strip out the escape sequences from text chunks but not tags
  329. for i in range(0, len(chunks)):
  330. c = chunks[i]
  331. if isinstance( c, TextChunk ):
  332. unescaped = c.text.replace(self.escape, "")
  333. if len(unescaped) < len(c.text):
  334. chunks[i] = TextChunk(unescaped)
  335. return chunks