grammar_parser.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import re
  2. import threading
  3. from typing import Any
  4. from antlr4 import CommonTokenStream, InputStream, ParserRuleContext
  5. from antlr4.error.ErrorListener import ErrorListener
  6. from .errors import GrammarParseError
  7. # Import from visitor in order to check the presence of generated grammar files
  8. # files in a single place.
  9. from .grammar_visitor import ( # type: ignore
  10. OmegaConfGrammarLexer,
  11. OmegaConfGrammarParser,
  12. )
  13. # Used to cache grammar objects to avoid re-creating them on each call to `parse()`.
  14. # We use a per-thread cache to make it thread-safe.
  15. _grammar_cache = threading.local()
  16. # Build regex pattern to efficiently identify typical interpolations.
  17. # See test `test_match_simple_interpolation_pattern` for examples.
  18. _config_key = r"[$\w]+" # foo, $0, $bar, $foo_$bar123$
  19. _key_maybe_brackets = f"{_config_key}|\\[{_config_key}\\]" # foo, [foo], [$bar]
  20. _node_access = f"\\.{_key_maybe_brackets}" # .foo, [foo], [$bar]
  21. _node_path = f"(\\.)*({_key_maybe_brackets})({_node_access})*" # [foo].bar, .foo[bar]
  22. _node_inter = f"\\${{\\s*{_node_path}\\s*}}" # node interpolation ${foo.bar}
  23. _id = "[a-zA-Z_][\\w\\-]*" # foo, foo_bar, foo-bar, abc123
  24. _resolver_name = f"({_id}(\\.{_id})*)?" # foo, ns.bar3, ns_1.ns_2.b0z
  25. _arg = r"[a-zA-Z_0-9/\-\+.$%*@?|]+" # string representing a resolver argument
  26. _args = f"{_arg}(\\s*,\\s*{_arg})*" # list of resolver arguments
  27. _resolver_inter = f"\\${{\\s*{_resolver_name}\\s*:\\s*{_args}?\\s*}}" # ${foo:bar}
  28. _inter = f"({_node_inter}|{_resolver_inter})" # any kind of interpolation
  29. _outer = "([^$]|\\$(?!{))+" # any character except $ (unless not followed by {)
  30. SIMPLE_INTERPOLATION_PATTERN = re.compile(
  31. f"({_outer})?({_inter}({_outer})?)+$", flags=re.ASCII
  32. )
  33. # NOTE: SIMPLE_INTERPOLATION_PATTERN must not generate false positive matches:
  34. # it must not accept anything that isn't a valid interpolation (per the
  35. # interpolation grammar defined in `omegaconf/grammar/*.g4`).
  36. class OmegaConfErrorListener(ErrorListener): # type: ignore
  37. def syntaxError(
  38. self,
  39. recognizer: Any,
  40. offending_symbol: Any,
  41. line: Any,
  42. column: Any,
  43. msg: Any,
  44. e: Any,
  45. ) -> None:
  46. raise GrammarParseError(str(e) if msg is None else msg) from e
  47. def reportAmbiguity(
  48. self,
  49. recognizer: Any,
  50. dfa: Any,
  51. startIndex: Any,
  52. stopIndex: Any,
  53. exact: Any,
  54. ambigAlts: Any,
  55. configs: Any,
  56. ) -> None:
  57. raise GrammarParseError("ANTLR error: Ambiguity") # pragma: no cover
  58. def reportAttemptingFullContext(
  59. self,
  60. recognizer: Any,
  61. dfa: Any,
  62. startIndex: Any,
  63. stopIndex: Any,
  64. conflictingAlts: Any,
  65. configs: Any,
  66. ) -> None:
  67. # Note: for now we raise an error to be safe. However this is mostly a
  68. # performance warning, so in the future this may be relaxed if we need
  69. # to change the grammar in such a way that this warning cannot be
  70. # avoided (another option would be to switch to SLL parsing mode).
  71. raise GrammarParseError(
  72. "ANTLR error: Attempting Full Context"
  73. ) # pragma: no cover
  74. def reportContextSensitivity(
  75. self,
  76. recognizer: Any,
  77. dfa: Any,
  78. startIndex: Any,
  79. stopIndex: Any,
  80. prediction: Any,
  81. configs: Any,
  82. ) -> None:
  83. raise GrammarParseError("ANTLR error: ContextSensitivity") # pragma: no cover
  84. def parse(
  85. value: str, parser_rule: str = "configValue", lexer_mode: str = "DEFAULT_MODE"
  86. ) -> ParserRuleContext:
  87. """
  88. Parse interpolated string `value` (and return the parse tree).
  89. """
  90. l_mode = getattr(OmegaConfGrammarLexer, lexer_mode)
  91. istream = InputStream(value)
  92. cached = getattr(_grammar_cache, "data", None)
  93. if cached is None:
  94. error_listener = OmegaConfErrorListener()
  95. lexer = OmegaConfGrammarLexer(istream)
  96. lexer.removeErrorListeners()
  97. lexer.addErrorListener(error_listener)
  98. lexer.mode(l_mode)
  99. token_stream = CommonTokenStream(lexer)
  100. parser = OmegaConfGrammarParser(token_stream)
  101. parser.removeErrorListeners()
  102. parser.addErrorListener(error_listener)
  103. # The two lines below could be enabled in the future if we decide to switch
  104. # to SLL prediction mode. Warning though, it has not been fully tested yet!
  105. # from antlr4 import PredictionMode
  106. # parser._interp.predictionMode = PredictionMode.SLL
  107. # Note that although the input stream `istream` is implicitly cached within
  108. # the lexer, it will be replaced by a new input next time the lexer is re-used.
  109. _grammar_cache.data = lexer, token_stream, parser
  110. else:
  111. lexer, token_stream, parser = cached
  112. # Replace the old input stream with the new one.
  113. lexer.inputStream = istream
  114. # Initialize the lexer / token stream / parser to process the new input.
  115. lexer.mode(l_mode)
  116. token_stream.setTokenSource(lexer)
  117. parser.reset()
  118. try:
  119. return getattr(parser, parser_rule)()
  120. except Exception as exc:
  121. if type(exc) is Exception and str(exc) == "Empty Stack":
  122. # This exception is raised by antlr when trying to pop a mode while
  123. # no mode has been pushed. We convert it into an `GrammarParseError`
  124. # to facilitate exception handling from the caller.
  125. raise GrammarParseError("Empty Stack")
  126. else:
  127. raise