tree_templates.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
"""This module defines utilities for matching and translating tree templates.
A tree template is a tree that contains nodes that are template variables.
"""
  4. from typing import Union, Optional, Mapping, Dict, Tuple, Iterator
  5. from lark import Tree, Transformer
  6. from lark.exceptions import MissingVariableError
# A single node position inside a tree: either a subtree or a string leaf.
Branch = Union[Tree[str], str]
# Either an already-built tree, or a string that TemplateConf._get_tree()
# hands to the configured parse callable to obtain one.
TreeOrCode = Union[Tree[str], str]
# Result of a successful match: template variable name -> the tree it matched.
MatchResult = Dict[str, Tree]
# Prefix that identifies a string as a template variable; checked (and
# stripped) by _get_template_name().
_TEMPLATE_MARKER = '$'
  11. class TemplateConf:
  12. """Template Configuration
  13. Allows customization for different uses of Template
  14. parse() must return a Tree instance.
  15. """
  16. def __init__(self, parse=None):
  17. self._parse = parse
  18. def test_var(self, var: Union[Tree[str], str]) -> Optional[str]:
  19. """Given a tree node, if it is a template variable return its name. Otherwise, return None.
  20. This method may be overridden for customization
  21. Parameters:
  22. var: Tree | str - The tree node to test
  23. """
  24. if isinstance(var, str):
  25. return _get_template_name(var)
  26. if (
  27. isinstance(var, Tree)
  28. and var.data == "var"
  29. and len(var.children) > 0
  30. and isinstance(var.children[0], str)
  31. ):
  32. return _get_template_name(var.children[0])
  33. return None
  34. def _get_tree(self, template: TreeOrCode) -> Tree[str]:
  35. if isinstance(template, str):
  36. assert self._parse
  37. template = self._parse(template)
  38. if not isinstance(template, Tree):
  39. raise TypeError("template parser must return a Tree instance")
  40. return template
  41. def __call__(self, template: Tree[str]) -> 'Template':
  42. return Template(template, conf=self)
  43. def _match_tree_template(self, template: TreeOrCode, tree: Branch) -> Optional[MatchResult]:
  44. """Returns dict of {var: match} if found a match, else None
  45. """
  46. template_var = self.test_var(template)
  47. if template_var:
  48. if not isinstance(tree, Tree):
  49. raise TypeError(f"Template variables can only match Tree instances. Not {tree!r}")
  50. return {template_var: tree}
  51. if isinstance(template, str):
  52. if template == tree:
  53. return {}
  54. return None
  55. assert isinstance(template, Tree) and isinstance(tree, Tree), f"template={template} tree={tree}"
  56. if template.data == tree.data and len(template.children) == len(tree.children):
  57. res = {}
  58. for t1, t2 in zip(template.children, tree.children):
  59. matches = self._match_tree_template(t1, t2)
  60. if matches is None:
  61. return None
  62. res.update(matches)
  63. return res
  64. return None
  65. class _ReplaceVars(Transformer[str, Tree[str]]):
  66. def __init__(self, conf: TemplateConf, vars: Mapping[str, Tree[str]]) -> None:
  67. super().__init__()
  68. self._conf = conf
  69. self._vars = vars
  70. def __default__(self, data, children, meta) -> Tree[str]:
  71. tree = super().__default__(data, children, meta)
  72. var = self._conf.test_var(tree)
  73. if var:
  74. try:
  75. return self._vars[var]
  76. except KeyError:
  77. raise MissingVariableError(f"No mapping for template variable ({var})")
  78. return tree
  79. class Template:
  80. """Represents a tree template, tied to a specific configuration
  81. A tree template is a tree that contains nodes that are template variables.
  82. Those variables will match any tree.
  83. (future versions may support annotations on the variables, to allow more complex templates)
  84. """
  85. def __init__(self, tree: Tree[str], conf: TemplateConf = TemplateConf()):
  86. self.conf = conf
  87. self.tree = conf._get_tree(tree)
  88. def match(self, tree: TreeOrCode) -> Optional[MatchResult]:
  89. """Match a tree template to a tree.
  90. A tree template without variables will only match ``tree`` if it is equal to the template.
  91. Parameters:
  92. tree (Tree): The tree to match to the template
  93. Returns:
  94. Optional[Dict[str, Tree]]: If match is found, returns a dictionary mapping
  95. template variable names to their matching tree nodes.
  96. If no match was found, returns None.
  97. """
  98. tree = self.conf._get_tree(tree)
  99. return self.conf._match_tree_template(self.tree, tree)
  100. def search(self, tree: TreeOrCode) -> Iterator[Tuple[Tree[str], MatchResult]]:
  101. """Search for all occurrences of the tree template inside ``tree``.
  102. """
  103. tree = self.conf._get_tree(tree)
  104. for subtree in tree.iter_subtrees():
  105. res = self.match(subtree)
  106. if res:
  107. yield subtree, res
  108. def apply_vars(self, vars: Mapping[str, Tree[str]]) -> Tree[str]:
  109. """Apply vars to the template tree
  110. """
  111. return _ReplaceVars(self.conf, vars).transform(self.tree)
  112. def translate(t1: Template, t2: Template, tree: TreeOrCode):
  113. """Search tree and translate each occurrence of t1 into t2.
  114. """
  115. tree = t1.conf._get_tree(tree) # ensure it's a tree, parse if necessary and possible
  116. for subtree, vars in t1.search(tree):
  117. res = t2.apply_vars(vars)
  118. subtree.set(res.data, res.children)
  119. return tree
  120. class TemplateTranslator:
  121. """Utility class for translating a collection of patterns
  122. """
  123. def __init__(self, translations: Mapping[Template, Template]):
  124. assert all(isinstance(k, Template) and isinstance(v, Template) for k, v in translations.items())
  125. self.translations = translations
  126. def translate(self, tree: Tree[str]):
  127. for k, v in self.translations.items():
  128. tree = translate(k, v, tree)
  129. return tree
  130. def _get_template_name(value: str) -> Optional[str]:
  131. return value.lstrip(_TEMPLATE_MARKER) if value.startswith(_TEMPLATE_MARKER) else None