cyk.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
"""This module implements a CYK parser."""
# Author: https://github.com/ehudt (2018)
#
# Adapted by Erez
  5. from collections import defaultdict
  6. import itertools
  7. from ..exceptions import ParseError
  8. from ..lexer import Token
  9. from ..tree import Tree
  10. from ..grammar import Terminal as T, NonTerminal as NT, Symbol
  11. def match(t, s):
  12. assert isinstance(t, T)
  13. return t.name == s.type
  14. class Rule:
  15. """Context-free grammar rule."""
  16. def __init__(self, lhs, rhs, weight, alias):
  17. super(Rule, self).__init__()
  18. assert isinstance(lhs, NT), lhs
  19. assert all(isinstance(x, NT) or isinstance(x, T) for x in rhs), rhs
  20. self.lhs = lhs
  21. self.rhs = rhs
  22. self.weight = weight
  23. self.alias = alias
  24. def __str__(self):
  25. return '%s -> %s' % (str(self.lhs), ' '.join(str(x) for x in self.rhs))
  26. def __repr__(self):
  27. return str(self)
  28. def __hash__(self):
  29. return hash((self.lhs, tuple(self.rhs)))
  30. def __eq__(self, other):
  31. return self.lhs == other.lhs and self.rhs == other.rhs
  32. def __ne__(self, other):
  33. return not (self == other)
  34. class Grammar:
  35. """Context-free grammar."""
  36. def __init__(self, rules):
  37. self.rules = frozenset(rules)
  38. def __eq__(self, other):
  39. return self.rules == other.rules
  40. def __str__(self):
  41. return '\n' + '\n'.join(sorted(repr(x) for x in self.rules)) + '\n'
  42. def __repr__(self):
  43. return str(self)
# Parse tree data structures
  45. class RuleNode:
  46. """A node in the parse tree, which also contains the full rhs rule."""
  47. def __init__(self, rule, children, weight=0):
  48. self.rule = rule
  49. self.children = children
  50. self.weight = weight
  51. def __repr__(self):
  52. return 'RuleNode(%s, [%s])' % (repr(self.rule.lhs), ', '.join(str(x) for x in self.children))
  53. class Parser:
  54. """Parser wrapper."""
  55. def __init__(self, rules):
  56. super(Parser, self).__init__()
  57. self.orig_rules = {rule: rule for rule in rules}
  58. rules = [self._to_rule(rule) for rule in rules]
  59. self.grammar = to_cnf(Grammar(rules))
  60. def _to_rule(self, lark_rule):
  61. """Converts a lark rule, (lhs, rhs, callback, options), to a Rule."""
  62. assert isinstance(lark_rule.origin, NT)
  63. assert all(isinstance(x, Symbol) for x in lark_rule.expansion)
  64. return Rule(
  65. lark_rule.origin, lark_rule.expansion,
  66. weight=lark_rule.options.priority if lark_rule.options.priority else 0,
  67. alias=lark_rule)
  68. def parse(self, tokenized, start): # pylint: disable=invalid-name
  69. """Parses input, which is a list of tokens."""
  70. assert start
  71. start = NT(start)
  72. table, trees = _parse(tokenized, self.grammar)
  73. # Check if the parse succeeded.
  74. if all(r.lhs != start for r in table[(0, len(tokenized) - 1)]):
  75. raise ParseError('Parsing failed.')
  76. parse = trees[(0, len(tokenized) - 1)][start]
  77. return self._to_tree(revert_cnf(parse))
  78. def _to_tree(self, rule_node):
  79. """Converts a RuleNode parse tree to a lark Tree."""
  80. orig_rule = self.orig_rules[rule_node.rule.alias]
  81. children = []
  82. for child in rule_node.children:
  83. if isinstance(child, RuleNode):
  84. children.append(self._to_tree(child))
  85. else:
  86. assert isinstance(child.name, Token)
  87. children.append(child.name)
  88. t = Tree(orig_rule.origin, children)
  89. t.rule=orig_rule
  90. return t
  91. def print_parse(node, indent=0):
  92. if isinstance(node, RuleNode):
  93. print(' ' * (indent * 2) + str(node.rule.lhs))
  94. for child in node.children:
  95. print_parse(child, indent + 1)
  96. else:
  97. print(' ' * (indent * 2) + str(node.s))
  98. def _parse(s, g):
  99. """Parses sentence 's' using CNF grammar 'g'."""
  100. # The CYK table. Indexed with a 2-tuple: (start pos, end pos)
  101. table = defaultdict(set)
  102. # Top-level structure is similar to the CYK table. Each cell is a dict from
  103. # rule name to the best (lightest) tree for that rule.
  104. trees = defaultdict(dict)
  105. # Populate base case with existing terminal production rules
  106. for i, w in enumerate(s):
  107. for terminal, rules in g.terminal_rules.items():
  108. if match(terminal, w):
  109. for rule in rules:
  110. table[(i, i)].add(rule)
  111. if (rule.lhs not in trees[(i, i)] or
  112. rule.weight < trees[(i, i)][rule.lhs].weight):
  113. trees[(i, i)][rule.lhs] = RuleNode(rule, [T(w)], weight=rule.weight)
  114. # Iterate over lengths of sub-sentences
  115. for l in range(2, len(s) + 1):
  116. # Iterate over sub-sentences with the given length
  117. for i in range(len(s) - l + 1):
  118. # Choose partition of the sub-sentence in [1, l)
  119. for p in range(i + 1, i + l):
  120. span1 = (i, p - 1)
  121. span2 = (p, i + l - 1)
  122. for r1, r2 in itertools.product(table[span1], table[span2]):
  123. for rule in g.nonterminal_rules.get((r1.lhs, r2.lhs), []):
  124. table[(i, i + l - 1)].add(rule)
  125. r1_tree = trees[span1][r1.lhs]
  126. r2_tree = trees[span2][r2.lhs]
  127. rule_total_weight = rule.weight + r1_tree.weight + r2_tree.weight
  128. if (rule.lhs not in trees[(i, i + l - 1)]
  129. or rule_total_weight < trees[(i, i + l - 1)][rule.lhs].weight):
  130. trees[(i, i + l - 1)][rule.lhs] = RuleNode(rule, [r1_tree, r2_tree], weight=rule_total_weight)
  131. return table, trees
# This section implements a converter from context-free grammars to Chomsky
# normal form (CNF). It also implements the conversion of parse trees from
# their CNF form back to the original grammar.
#
# Overview:
# Applies the following operations in this order:
# * TERM: Eliminates non-solitary terminals from all rules.
# * BIN: Eliminates rules with more than 2 symbols on their right-hand side.
# * UNIT: Eliminates non-terminal unit rules.
#
# The following grammar characteristics aren't handled:
# * Start symbol appears on RHS
# * Empty rules (epsilon rules)
  144. class CnfWrapper:
  145. """CNF wrapper for grammar.
  146. Validates that the input grammar is CNF and provides helper data structures.
  147. """
  148. def __init__(self, grammar):
  149. super(CnfWrapper, self).__init__()
  150. self.grammar = grammar
  151. self.rules = grammar.rules
  152. self.terminal_rules = defaultdict(list)
  153. self.nonterminal_rules = defaultdict(list)
  154. for r in self.rules:
  155. # Validate that the grammar is CNF and populate auxiliary data structures.
  156. assert isinstance(r.lhs, NT), r
  157. if len(r.rhs) not in [1, 2]:
  158. raise ParseError("CYK doesn't support empty rules")
  159. if len(r.rhs) == 1 and isinstance(r.rhs[0], T):
  160. self.terminal_rules[r.rhs[0]].append(r)
  161. elif len(r.rhs) == 2 and all(isinstance(x, NT) for x in r.rhs):
  162. self.nonterminal_rules[tuple(r.rhs)].append(r)
  163. else:
  164. assert False, r
  165. def __eq__(self, other):
  166. return self.grammar == other.grammar
  167. def __repr__(self):
  168. return repr(self.grammar)
  169. class UnitSkipRule(Rule):
  170. """A rule that records NTs that were skipped during transformation."""
  171. def __init__(self, lhs, rhs, skipped_rules, weight, alias):
  172. super(UnitSkipRule, self).__init__(lhs, rhs, weight, alias)
  173. self.skipped_rules = skipped_rules
  174. def __eq__(self, other):
  175. return isinstance(other, type(self)) and self.skipped_rules == other.skipped_rules
  176. __hash__ = Rule.__hash__
  177. def build_unit_skiprule(unit_rule, target_rule):
  178. skipped_rules = []
  179. if isinstance(unit_rule, UnitSkipRule):
  180. skipped_rules += unit_rule.skipped_rules
  181. skipped_rules.append(target_rule)
  182. if isinstance(target_rule, UnitSkipRule):
  183. skipped_rules += target_rule.skipped_rules
  184. return UnitSkipRule(unit_rule.lhs, target_rule.rhs, skipped_rules,
  185. weight=unit_rule.weight + target_rule.weight, alias=unit_rule.alias)
  186. def get_any_nt_unit_rule(g):
  187. """Returns a non-terminal unit rule from 'g', or None if there is none."""
  188. for rule in g.rules:
  189. if len(rule.rhs) == 1 and isinstance(rule.rhs[0], NT):
  190. return rule
  191. return None
  192. def _remove_unit_rule(g, rule):
  193. """Removes 'rule' from 'g' without changing the language produced by 'g'."""
  194. new_rules = [x for x in g.rules if x != rule]
  195. refs = [x for x in g.rules if x.lhs == rule.rhs[0]]
  196. new_rules += [build_unit_skiprule(rule, ref) for ref in refs]
  197. return Grammar(new_rules)
  198. def _split(rule):
  199. """Splits a rule whose len(rhs) > 2 into shorter rules."""
  200. rule_str = str(rule.lhs) + '__' + '_'.join(str(x) for x in rule.rhs)
  201. rule_name = '__SP_%s' % (rule_str) + '_%d'
  202. yield Rule(rule.lhs, [rule.rhs[0], NT(rule_name % 1)], weight=rule.weight, alias=rule.alias)
  203. for i in range(1, len(rule.rhs) - 2):
  204. yield Rule(NT(rule_name % i), [rule.rhs[i], NT(rule_name % (i + 1))], weight=0, alias='Split')
  205. yield Rule(NT(rule_name % (len(rule.rhs) - 2)), rule.rhs[-2:], weight=0, alias='Split')
  206. def _term(g):
  207. """Applies the TERM rule on 'g' (see top comment)."""
  208. all_t = {x for rule in g.rules for x in rule.rhs if isinstance(x, T)}
  209. t_rules = {t: Rule(NT('__T_%s' % str(t)), [t], weight=0, alias='Term') for t in all_t}
  210. new_rules = []
  211. for rule in g.rules:
  212. if len(rule.rhs) > 1 and any(isinstance(x, T) for x in rule.rhs):
  213. new_rhs = [t_rules[x].lhs if isinstance(x, T) else x for x in rule.rhs]
  214. new_rules.append(Rule(rule.lhs, new_rhs, weight=rule.weight, alias=rule.alias))
  215. new_rules.extend(v for k, v in t_rules.items() if k in rule.rhs)
  216. else:
  217. new_rules.append(rule)
  218. return Grammar(new_rules)
  219. def _bin(g):
  220. """Applies the BIN rule to 'g' (see top comment)."""
  221. new_rules = []
  222. for rule in g.rules:
  223. if len(rule.rhs) > 2:
  224. new_rules += _split(rule)
  225. else:
  226. new_rules.append(rule)
  227. return Grammar(new_rules)
  228. def _unit(g):
  229. """Applies the UNIT rule to 'g' (see top comment)."""
  230. nt_unit_rule = get_any_nt_unit_rule(g)
  231. while nt_unit_rule:
  232. g = _remove_unit_rule(g, nt_unit_rule)
  233. nt_unit_rule = get_any_nt_unit_rule(g)
  234. return g
  235. def to_cnf(g):
  236. """Creates a CNF grammar from a general context-free grammar 'g'."""
  237. g = _unit(_bin(_term(g)))
  238. return CnfWrapper(g)
  239. def unroll_unit_skiprule(lhs, orig_rhs, skipped_rules, children, weight, alias):
  240. if not skipped_rules:
  241. return RuleNode(Rule(lhs, orig_rhs, weight=weight, alias=alias), children, weight=weight)
  242. else:
  243. weight = weight - skipped_rules[0].weight
  244. return RuleNode(
  245. Rule(lhs, [skipped_rules[0].lhs], weight=weight, alias=alias), [
  246. unroll_unit_skiprule(skipped_rules[0].lhs, orig_rhs,
  247. skipped_rules[1:], children,
  248. skipped_rules[0].weight, skipped_rules[0].alias)
  249. ], weight=weight)
  250. def revert_cnf(node):
  251. """Reverts a parse tree (RuleNode) to its original non-CNF form (Node)."""
  252. if isinstance(node, T):
  253. return node
  254. # Reverts TERM rule.
  255. if node.rule.lhs.name.startswith('__T_'):
  256. return node.children[0]
  257. else:
  258. children = []
  259. for child in map(revert_cnf, node.children):
  260. # Reverts BIN rule.
  261. if isinstance(child, RuleNode) and child.rule.lhs.name.startswith('__SP_'):
  262. children += child.children
  263. else:
  264. children.append(child)
  265. # Reverts UNIT rule.
  266. if isinstance(node.rule, UnitSkipRule):
  267. return unroll_unit_skiprule(node.rule.lhs, node.rule.rhs,
  268. node.rule.skipped_rules, children,
  269. node.rule.weight, node.rule.alias)
  270. else:
  271. return RuleNode(node.rule, children)