# tree_matcher.py (6.6 KB)
  1. """Tree matcher based on Lark grammar"""
  2. import re
  3. from typing import List, Dict
  4. from collections import defaultdict
  5. from . import Tree, Token, Lark
  6. from .common import ParserConf
  7. from .exceptions import ConfigurationError
  8. from .parsers import earley
  9. from .grammar import Rule, Terminal, NonTerminal
  def is_discarded_terminal(t):
      """Tell whether symbol ``t`` is a terminal that is flagged to be filtered out.

      Returns ``t.is_term`` when that is falsy, otherwise ``t.filter_out``.
      """
      is_term = t.is_term
      return t.filter_out if is_term else is_term
  class _MakeTreeMatch:
      """Callable that wraps matched children in a ``Tree`` tagged as a match result.

      ``make_recons_rule`` installs instances of this class as the ``alias`` of the
      rules it builds.
      """

      def __init__(self, name, expansion):
          # `name` becomes the data of the produced tree; `expansion` is recorded on its meta.
          self.name, self.expansion = name, expansion

      def __call__(self, args):
          """Build ``Tree(self.name, args)`` and mark its meta as a tree match."""
          matched = Tree(self.name, args)
          meta = matched.meta
          meta.match_tree = True
          # Remember the expansion this matcher was created with.
          meta.orig_expansion = self.expansion
          return matched
  def _best_from_group(seq, group_key, cmp_key):
      """Keep one item per group: the one with the smallest ``cmp_key``.

      Items are grouped by ``group_key(item)``. Within a group the current holder is
      replaced only when its ``cmp_key`` is strictly greater than the newcomer's, so
      the earliest item wins ties. Groups appear in the result in the order their
      first item was encountered.
      """
      best = {}
      for candidate in seq:
          group = group_key(candidate)
          if group not in best:
              best[group] = candidate
              continue
          candidate_score = cmp_key(candidate)
          holder_score = cmp_key(best[group])
          if holder_score > candidate_score:
              best[group] = candidate
      return list(best.values())
  def _best_rules_from_group(rules: List[Rule]) -> List[Rule]:
      """Collapse rules that compare equal and order the survivors by expansion length.

      Each rule is its own group key, so only rules that hash and compare equal share
      a group. Within a group the rule with the longest expansion is kept (the first
      one seen on ties). The result is sorted by ascending expansion length.
      """
      def expansion_size(rule):
          return len(rule.expansion)

      survivors = _best_from_group(rules, lambda rule: rule, lambda rule: -expansion_size(rule))
      return sorted(survivors, key=expansion_size)
  def _match(term, token):
      """Decide whether grammar symbol ``term`` matches ``token`` (a ``Tree`` or a ``Token``).

      A ``Tree`` matches when its ``data`` equals the base name parsed out of
      ``term.name`` (template arguments are ignored). A ``Token`` matches when ``term``
      equals ``Terminal(token.type)``. Anything else trips the assertion.
      """
      if isinstance(token, Tree):
          rule_name, _template_args = parse_rulename(term.name)
          return token.data == rule_name

      if isinstance(token, Token):
          expected = Terminal(token.type)
          return term == expected

      assert False, (term, token)
  def make_recons_rule(origin, expansion, old_expansion):
      """Create a ``Rule`` for ``origin`` -> ``expansion`` whose alias builds a match tree.

      The alias is a ``_MakeTreeMatch`` carrying ``origin.name`` and ``old_expansion``.
      """
      tree_builder = _MakeTreeMatch(origin.name, old_expansion)
      return Rule(origin, expansion, alias=tree_builder)
  def make_recons_rule_to_term(origin, term):
      """Create a rule matching ``origin`` against a single terminal named after ``term``.

      The new expansion is ``[Terminal(term.name)]``; ``[term]`` is passed along as the
      old expansion.
      """
      as_terminal = Terminal(term.name)
      return make_recons_rule(origin, [as_terminal], [term])
  def parse_rulename(s):
      """Parse rule names that may contain a template syntax (like rule{a, b, ...})

      Parameters:
          s (str): the rule name, optionally followed by ``{arg, ...}``.

      Returns:
          tuple: ``(name, args)``. ``args`` is the list of template arguments with
          surrounding whitespace stripped, or ``None`` when there are none.

      Raises:
          ValueError: If ``s`` does not start with a rule name (``\\w+``).
      """
      m = re.match(r'(\w+)(?:{(.+)})?', s)
      if m is None:
          # Without this guard the failure would surface as an AttributeError on None.
          raise ValueError("Invalid rule name: %r" % (s,))
      name, args_str = m.groups()
      # args_str is None when no `{...}` part is present; keep None in that case.
      args = args_str and [a.strip() for a in args_str.split(',')]
      return name, args
  class ChildrenLexer:
      """Lexer stand-in that hands back an existing sequence of children as the token stream."""

      def __init__(self, children):
          # Stored as given; no copy is made.
          self.children = children

      def lex(self, parser_state):
          """Return the stored children; ``parser_state`` is accepted but not used."""
          return self.children
  class TreeMatcher:
      """Match the elements of a tree node, based on an ontology
      provided by a Lark grammar.

      Supports templates and inlined rules (`rule{a, b,..}` and `_rule`)

      Initialize with an instance of Lark.
      """
      # Rules usable only when matching a specific root, keyed by rule name.
      # Filled in as a side effect of consuming _build_recons_rules().
      rules_for_root: Dict[str, List[Rule]]
      # Matching rules shared by every root; match_tree() combines them with
      # rules_for_root[rulename].
      rules: List[Rule]
      # The Lark instance given to the constructor; match_tree() reads its lexer_conf.
      parser: Lark

      def __init__(self, parser: Lark):
          """Derive the tree-matching rules from ``parser``.

          Parameters:
              parser (Lark): the Lark instance whose rules and terminals are used.

          Raises:
              AssertionError: If ``parser.options.maybe_placeholders`` is truthy.
              ConfigurationError: If the grammar has to be recompiled (the postlexer
                  sets ``always_accept``) but ``parser`` has no ``grammar`` attribute.
          """
          # XXX TODO calling compile twice returns different results!
          assert not parser.options.maybe_placeholders

          if parser.options.postlex and parser.options.postlex.always_accept:
              # If postlexer's always_accept is used, we need to recompile the grammar with empty terminals-to-keep
              if not hasattr(parser, 'grammar'):
                  raise ConfigurationError('Source grammar not available from cached parser, use cache_grammar=True'
                                           if parser.options.cache else "Source grammar not available!")
              self.tokens, rules, _extra = parser.grammar.compile(parser.options.start, set())
          else:
              self.tokens = list(parser.terminals)
              rules = list(parser.rules)

          # Must exist before the generator below is consumed: _build_recons_rules()
          # appends to it while it runs.
          self.rules_for_root = defaultdict(list)

          self.rules = list(self._build_recons_rules(rules))
          self.rules.reverse()

          # Choose the best rule from each group of {rule => [rule.alias]}, since we only really need one derivation.
          self.rules = _best_rules_from_group(self.rules)

          self.parser = parser
          # One Earley parser per rule name, created lazily by match_tree().
          self._parser_cache: Dict[str, earley.Parser] = {}

      def _build_recons_rules(self, rules: List[Rule]):
          """Convert tree-parsing/construction rules to tree-matching rules

          This is a generator. It yields the rules that are shared by all roots and,
          as a side effect, appends root-specific rules to ``self.rules_for_root``.
          """
          # Origins of rules flagged with the expand1 option.
          expand1s = {r.origin for r in rules if r.options.expand1}

          # origin -> aliases used by its rules.
          aliases = defaultdict(list)
          for r in rules:
              if r.alias:
                  aliases[r.origin].append(r.alias)

          rule_names = {r.origin for r in rules}
          # Origins that stay non-terminals in the matching grammar: names starting with
          # '_', expand1 origins, and origins that have aliases. Every other symbol is
          # turned into a Terminal below.
          nonterminals = {sym for sym in rule_names
                          if sym.name.startswith('_') or sym in expand1s or sym in aliases}

          # Names for which a "symbol -> terminal of the same name" rule was already yielded.
          seen = set()
          for r in rules:
              # Drop discarded terminals; map symbols outside `nonterminals` to Terminals.
              recons_exp = [sym if sym in nonterminals else Terminal(sym.name)
                            for sym in r.expansion if not is_discarded_terminal(sym)]

              # Skip self-recursive constructs
              if recons_exp == [r.origin] and r.alias is None:
                  continue

              # An aliased rule is matched under its alias name rather than its origin.
              sym = NonTerminal(r.alias) if r.alias else r.origin
              rule = make_recons_rule(sym, recons_exp, r.expansion)

              if sym in expand1s and len(recons_exp) != 1:
                  # Kept as a root-only rule; other rules refer to it through the
                  # terminal-form rule yielded once per name below.
                  self.rules_for_root[sym.name].append(rule)

                  if sym.name not in seen:
                      yield make_recons_rule_to_term(sym, sym)
                      seen.add(sym.name)
              else:
                  if sym.name.startswith('_') or sym in expand1s:
                      yield rule
                  else:
                      self.rules_for_root[sym.name].append(rule)

          # For every origin that has aliases: one rule per alias, plus one for the origin itself.
          for origin, rule_aliases in aliases.items():
              for alias in rule_aliases:
                  yield make_recons_rule_to_term(origin, NonTerminal(alias))
              yield make_recons_rule_to_term(origin, origin)

      def match_tree(self, tree: Tree, rulename: str) -> Tree:
          """Match the elements of `tree` to the symbols of rule `rulename`.

          Parameters:
              tree (Tree): the tree node to match
              rulename (str): The expected full rule name (including template args).
                  If falsy, ``tree.data`` is used instead.

          Returns:
              Tree: an unreduced tree that matches `rulename`

          Raises:
              UnexpectedToken: If no match was found.
              AssertionError: If ``tree.data`` differs from the base name of `rulename`,
                  or the resulting tree's data differs from `rulename`.

          Note:
              It's the callers' responsibility to match the tree recursively.
          """
          if rulename:
              # validate
              name, _args = parse_rulename(rulename)
              assert tree.data == name
          else:
              rulename = tree.data

          # TODO: ambiguity?
          try:
              parser = self._parser_cache[rulename]
          except KeyError:
              # First request for this rule name: build its parser from the shared rules
              # plus the rules that apply only to this root, then cache it.
              rules = self.rules + _best_rules_from_group(self.rules_for_root[rulename])

              # TODO pass callbacks through dict, instead of alias?
              callbacks = {rule: rule.alias for rule in rules}
              conf = ParserConf(rules, callbacks, [rulename])  # type: ignore[arg-type]
              parser = earley.Parser(self.parser.lexer_conf, conf, _match, resolve_ambiguity=True)
              self._parser_cache[rulename] = parser

          # find a full derivation
          # The tree's children are fed to the parser as its token stream.
          unreduced_tree: Tree = parser.parse(ChildrenLexer(tree.children), rulename)
          assert unreduced_tree.data == rulename
          return unreduced_tree