Trees.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. #
  2. # Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
  3. # Use of this file is governed by the BSD 3-clause license that
  4. # can be found in the LICENSE.txt file in the project root.
  5. #
  6. # A set of utility routines useful for all kinds of ANTLR trees.
  7. from io import StringIO
  8. from antlr4.Token import Token
  9. from antlr4.Utils import escapeWhitespace
  10. from antlr4.tree.Tree import RuleNode, ErrorNode, TerminalNode, Tree, ParseTree
  11. # need forward declaration
  12. Parser = None
  13. class Trees(object):
  14. # Print out a whole tree in LISP form. {@link #getNodeText} is used on the
  15. # node payloads to get the text for the nodes. Detect
  16. # parse trees and extract data appropriately.
  17. @classmethod
  18. def toStringTree(cls, t:Tree, ruleNames:list=None, recog:Parser=None):
  19. if recog is not None:
  20. ruleNames = recog.ruleNames
  21. s = escapeWhitespace(cls.getNodeText(t, ruleNames), False)
  22. if t.getChildCount()==0:
  23. return s
  24. with StringIO() as buf:
  25. buf.write("(")
  26. buf.write(s)
  27. buf.write(' ')
  28. for i in range(0, t.getChildCount()):
  29. if i > 0:
  30. buf.write(' ')
  31. buf.write(cls.toStringTree(t.getChild(i), ruleNames))
  32. buf.write(")")
  33. return buf.getvalue()
  34. @classmethod
  35. def getNodeText(cls, t:Tree, ruleNames:list=None, recog:Parser=None):
  36. if recog is not None:
  37. ruleNames = recog.ruleNames
  38. if ruleNames is not None:
  39. if isinstance(t, RuleNode):
  40. if t.getAltNumber()!=0: # should use ATN.INVALID_ALT_NUMBER but won't compile
  41. return ruleNames[t.getRuleIndex()]+":"+str(t.getAltNumber())
  42. return ruleNames[t.getRuleIndex()]
  43. elif isinstance( t, ErrorNode):
  44. return str(t)
  45. elif isinstance(t, TerminalNode):
  46. if t.symbol is not None:
  47. return t.symbol.text
  48. # no recog for rule names
  49. payload = t.getPayload()
  50. if isinstance(payload, Token ):
  51. return payload.text
  52. return str(t.getPayload())
  53. # Return ordered list of all children of this node
  54. @classmethod
  55. def getChildren(cls, t:Tree):
  56. return [ t.getChild(i) for i in range(0, t.getChildCount()) ]
  57. # Return a list of all ancestors of this node. The first node of
  58. # list is the root and the last is the parent of this node.
  59. #
  60. @classmethod
  61. def getAncestors(cls, t:Tree):
  62. ancestors = []
  63. t = t.getParent()
  64. while t is not None:
  65. ancestors.insert(0, t) # insert at start
  66. t = t.getParent()
  67. return ancestors
  68. @classmethod
  69. def findAllTokenNodes(cls, t:ParseTree, ttype:int):
  70. return cls.findAllNodes(t, ttype, True)
  71. @classmethod
  72. def findAllRuleNodes(cls, t:ParseTree, ruleIndex:int):
  73. return cls.findAllNodes(t, ruleIndex, False)
  74. @classmethod
  75. def findAllNodes(cls, t:ParseTree, index:int, findTokens:bool):
  76. nodes = []
  77. cls._findAllNodes(t, index, findTokens, nodes)
  78. return nodes
  79. @classmethod
  80. def _findAllNodes(cls, t:ParseTree, index:int, findTokens:bool, nodes:list):
  81. from antlr4.ParserRuleContext import ParserRuleContext
  82. # check this node (the root) first
  83. if findTokens and isinstance(t, TerminalNode):
  84. if t.symbol.type==index:
  85. nodes.append(t)
  86. elif not findTokens and isinstance(t, ParserRuleContext):
  87. if t.ruleIndex == index:
  88. nodes.append(t)
  89. # check children
  90. for i in range(0, t.getChildCount()):
  91. cls._findAllNodes(t.getChild(i), index, findTokens, nodes)
  92. @classmethod
  93. def descendants(cls, t:ParseTree):
  94. nodes = [t]
  95. for i in range(0, t.getChildCount()):
  96. nodes.extend(cls.descendants(t.getChild(i)))
  97. return nodes