| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392 |
- import sys
- import warnings
- from itertools import zip_longest
- from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Dict,
- Generator,
- List,
- Optional,
- Set,
- Tuple,
- Union,
- )
- from antlr4 import TerminalNode
- from .errors import InterpolationResolutionError
- if TYPE_CHECKING:
- from .base import Node # noqa F401
- try:
- from omegaconf.grammar.gen.OmegaConfGrammarLexer import OmegaConfGrammarLexer
- from omegaconf.grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
- from omegaconf.grammar.gen.OmegaConfGrammarParserVisitor import (
- OmegaConfGrammarParserVisitor,
- )
- except ModuleNotFoundError: # pragma: no cover
- print(
- "Error importing OmegaConf's generated parsers, run `python setup.py antlr` to regenerate.",
- file=sys.stderr,
- )
- sys.exit(1)
class GrammarVisitor(OmegaConfGrammarParserVisitor):
    """
    Visitor evaluating parse trees produced by `OmegaConfGrammarParser`.

    Primitives, quoted strings, lists, dicts and escape sequences are evaluated
    directly by this class. Node interpolations (ex: `${foo.bar}`) and resolver
    interpolations are delegated to the callbacks given to the constructor.
    """

    def __init__(
        self,
        node_interpolation_callback: Callable[
            [str, Optional[Set[int]]],
            Optional["Node"],
        ],
        resolver_interpolation_callback: Callable[..., Any],
        memo: Optional[Set[int]],
        **kw: Any,
    ):
        """
        Constructor.

        :param node_interpolation_callback: Callback function that is called when
            needing to resolve a node interpolation. It is called with two
            positional arguments: the key's dot path as a string (ex: `"foo.bar"`,
            or `".foo"` for a relative interpolation) and the `memo` object given
            to this constructor. It returns the resolved node, or `None`.

        :param resolver_interpolation_callback: Callback function that is called when
            needing to resolve a resolver interpolation. This function should accept
            three keyword arguments: `name` (str, the name of the resolver),
            `args` (tuple, the inputs to the resolver), and `args_str` (tuple,
            the string representation of the inputs to the resolver).

        :param memo: Optional set of integers that is forwarded unchanged to
            `node_interpolation_callback` on every node interpolation. Its meaning
            is defined by that callback (presumably it tracks interpolations that
            are already being resolved -- confirm against the caller).

        :param kw: Additional keyword arguments to be forwarded to parent class.
        """
        super().__init__(**kw)
        self.node_interpolation_callback = node_interpolation_callback
        self.resolver_interpolation_callback = resolver_interpolation_callback
        self.memo = memo

    def aggregateResult(self, aggregate: List[Any], nextResult: Any) -> List[Any]:
        # Raising an exception because not currently used: every `visit*()` method
        # of this class computes its result explicitly.
        raise NotImplementedError

    def defaultResult(self) -> List[Any]:
        # Raising an exception because not currently used (like `aggregateResult()`).
        raise NotImplementedError

    def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str:
        """
        Return the string denoted by a config key.

        A plain `ID` / `INTER_KEY` token is returned verbatim. An interpolation is
        resolved and must produce a string.

        :raises InterpolationResolutionError: if the interpolation does not resolve
            to a `str`.
        """
        from ._utils import _get_value

        # interpolation | ID | INTER_KEY
        assert ctx.getChildCount() == 1
        child = ctx.getChild(0)
        if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
            res = _get_value(self.visitInterpolation(child))
            if not isinstance(res, str):
                raise InterpolationResolutionError(
                    f"The following interpolation is used to denote a config key and "
                    f"thus should return a string, but instead returned `{res}` of "
                    f"type `{type(res)}`: {child.getText()}"
                )
            return res
        else:
            assert isinstance(child, TerminalNode) and isinstance(
                child.symbol.text, str
            )
            return child.symbol.text

    def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any:
        """Evaluate a top-level config value by visiting its `text` child."""
        # text EOF
        assert ctx.getChildCount() == 2
        return self.visit(ctx.getChild(0))

    def visitDictKey(self, ctx: OmegaConfGrammarParser.DictKeyContext) -> Any:
        """Evaluate a dict key with the same rules as a primitive value."""
        return self._createPrimitive(ctx)

    def visitDictContainer(
        self, ctx: OmegaConfGrammarParser.DictContainerContext
    ) -> Dict[Any, Any]:
        """
        Build a dict from a `{key: value, ...}` container.

        Key/value pairs sit at odd child indices (1, 3, ...); the even indices hold
        the braces and commas. Later duplicate keys overwrite earlier ones.
        """
        # BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE
        assert ctx.getChildCount() >= 2
        return dict(
            self.visitDictKeyValuePair(ctx.getChild(i))
            for i in range(1, ctx.getChildCount() - 1, 2)
        )

    def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any:
        """Evaluate an element by visiting its single child."""
        # primitive | quotedValue | listContainer | dictContainer
        assert ctx.getChildCount() == 1
        return self.visit(ctx.getChild(0))

    def visitInterpolation(
        self, ctx: OmegaConfGrammarParser.InterpolationContext
    ) -> Any:
        """Resolve an interpolation (node or resolver) by visiting its child."""
        assert ctx.getChildCount() == 1  # interpolationNode | interpolationResolver
        return self.visit(ctx.getChild(0))

    def visitInterpolationNode(
        self, ctx: OmegaConfGrammarParser.InterpolationNodeContext
    ) -> Optional["Node"]:
        """
        Resolve a node interpolation.

        The dot path is rebuilt from the children. Leading dots (relative
        interpolations), dots and brackets are kept verbatim, and each config key is
        evaluated. The resulting string is passed, together with `self.memo`, to
        `node_interpolation_callback`, whose result is returned.
        """
        # INTER_OPEN
        # DOT*                                                     // relative interpolation?
        # (configKey | BRACKET_OPEN configKey BRACKET_CLOSE)       // foo, [foo]
        # (DOT configKey | BRACKET_OPEN configKey BRACKET_CLOSE)*  // .foo, [foo], .foo[bar], [foo].bar[baz]
        # INTER_CLOSE;
        assert ctx.getChildCount() >= 3

        inter_key_tokens = []  # parsed elements of the dot path
        for child in ctx.getChildren():
            if isinstance(child, TerminalNode):
                s = child.symbol
                if s.type in [
                    OmegaConfGrammarLexer.DOT,
                    OmegaConfGrammarLexer.BRACKET_OPEN,
                    OmegaConfGrammarLexer.BRACKET_CLOSE,
                ]:
                    inter_key_tokens.append(s.text)
                else:
                    # The interpolation delimiters are not part of the dot path.
                    assert s.type in (
                        OmegaConfGrammarLexer.INTER_OPEN,
                        OmegaConfGrammarLexer.INTER_CLOSE,
                    )
            else:
                assert isinstance(child, OmegaConfGrammarParser.ConfigKeyContext)
                inter_key_tokens.append(self.visitConfigKey(child))

        inter_key = "".join(inter_key_tokens)
        return self.node_interpolation_callback(inter_key, self.memo)

    def visitInterpolationResolver(
        self, ctx: OmegaConfGrammarParser.InterpolationResolverContext
    ) -> Any:
        """
        Resolve a resolver interpolation.

        Calls `resolver_interpolation_callback` with the keyword arguments `name`
        (resolver name), `args` (tuple of evaluated arguments) and `args_str`
        (tuple of the arguments' original text). Both tuples are empty when the
        resolver has no arguments.
        """
        # INTER_OPEN resolverName COLON sequence? BRACE_CLOSE
        assert 4 <= ctx.getChildCount() <= 5

        resolver_name = self.visit(ctx.getChild(1))
        maybe_seq = ctx.getChild(3)
        args = []
        args_str = []
        if isinstance(maybe_seq, TerminalNode):  # means there are no args
            assert maybe_seq.symbol.type == OmegaConfGrammarLexer.BRACE_CLOSE
        else:
            assert isinstance(maybe_seq, OmegaConfGrammarParser.SequenceContext)
            for val, txt in self.visitSequence(maybe_seq):
                args.append(val)
                args_str.append(txt)

        return self.resolver_interpolation_callback(
            name=resolver_name,
            args=tuple(args),
            args_str=tuple(args_str),
        )

    def visitDictKeyValuePair(
        self, ctx: OmegaConfGrammarParser.DictKeyValuePairContext
    ) -> Tuple[Any, Any]:
        """
        Return the `(key, value)` tuple of a dict entry.

        The value is passed through `_get_value`; the key is returned as evaluated.
        """
        from ._utils import _get_value

        assert ctx.getChildCount() == 3  # dictKey COLON element
        key = self.visit(ctx.getChild(0))
        colon = ctx.getChild(1)
        assert (
            isinstance(colon, TerminalNode)
            and colon.symbol.type == OmegaConfGrammarLexer.COLON
        )
        value = _get_value(self.visitElement(ctx.getChild(2)))
        return key, value

    def visitListContainer(
        self, ctx: OmegaConfGrammarParser.ListContainerContext
    ) -> List[Any]:
        """Build a list from a `[...]` container; `[]` gives an empty list."""
        # BRACKET_OPEN sequence? BRACKET_CLOSE;
        assert ctx.getChildCount() in (2, 3)
        if ctx.getChildCount() == 2:
            return []
        sequence = ctx.getChild(1)
        assert isinstance(sequence, OmegaConfGrammarParser.SequenceContext)
        # Keep only the evaluated values; the raw text is not needed here.
        return [val for val, _ in self.visitSequence(sequence)]

    def visitPrimitive(self, ctx: OmegaConfGrammarParser.PrimitiveContext) -> Any:
        """Evaluate an unquoted primitive value (see `_createPrimitive`)."""
        return self._createPrimitive(ctx)

    def visitQuotedValue(self, ctx: OmegaConfGrammarParser.QuotedValueContext) -> str:
        """
        Return the content of a quoted string as a `str`.

        Empty quotes give `""`. A non-string result (ex: a lone interpolation
        between the quotes) is converted with `str()`.
        """
        # (QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE) text? MATCHING_QUOTE_CLOSE
        n = ctx.getChildCount()
        assert n in [2, 3]
        return str(self.visit(ctx.getChild(1))) if n == 3 else ""

    def visitResolverName(self, ctx: OmegaConfGrammarParser.ResolverNameContext) -> str:
        """
        Return the dot-joined resolver name.

        Each part is either an `ID` token or an interpolation that must resolve to
        a string.

        :raises InterpolationResolutionError: if an interpolated part is not a
            `str`.
        """
        from ._utils import _get_value

        # (interpolation | ID) (DOT (interpolation | ID))*
        assert ctx.getChildCount() >= 1
        items = []
        # Even-indexed children are the name parts; odd-indexed ones are DOTs.
        for child in list(ctx.getChildren())[::2]:
            if isinstance(child, TerminalNode):
                assert child.symbol.type == OmegaConfGrammarLexer.ID
                items.append(child.symbol.text)
            else:
                assert isinstance(child, OmegaConfGrammarParser.InterpolationContext)
                item = _get_value(self.visitInterpolation(child))
                if not isinstance(item, str):
                    raise InterpolationResolutionError(
                        f"The name of a resolver must be a string, but the interpolation "
                        f"{child.getText()} resolved to `{item}` which is of type "
                        f"{type(item)}"
                    )
                items.append(item)
        return ".".join(items)

    def visitSequence(
        self, ctx: OmegaConfGrammarParser.SequenceContext
    ) -> Generator[Any, None, None]:
        """
        Yield `(value, original_text)` for each element of a comma-separated sequence.

        A missing element (leading, consecutive or trailing comma) yields `("", "")`
        and emits a deprecation `UserWarning`.
        """
        from ._utils import _get_value

        # (element (COMMA element?)*) | (COMMA element?)+
        assert ctx.getChildCount() >= 1

        # DEPRECATED: remove in 2.2 (revert #571)
        def empty_str_warning() -> None:
            txt = ctx.getText()
            warnings.warn(
                f"In the sequence `{txt}` some elements are missing: please replace "
                f"them with empty quoted strings. "
                f"See https://github.com/omry/omegaconf/issues/572 for details.",
                category=UserWarning,
            )

        is_previous_comma = True  # whether previous child was a comma (init to True)
        for child in ctx.getChildren():
            if isinstance(child, OmegaConfGrammarParser.ElementContext):
                # Also preserve the original text representation of `child` so
                # as to allow backward compatibility with old resolvers (registered
                # with `legacy_register_resolver()`). Note that we cannot just cast
                # the value to string later as for instance `null` would become "None".
                yield _get_value(self.visitElement(child)), child.getText()
                is_previous_comma = False
            else:
                assert (
                    isinstance(child, TerminalNode)
                    and child.symbol.type == OmegaConfGrammarLexer.COMMA
                )
                if is_previous_comma:
                    # Leading or consecutive comma: an element is missing.
                    empty_str_warning()
                    yield "", ""
                else:
                    is_previous_comma = True

        if is_previous_comma:
            # Trailing comma.
            empty_str_warning()
            yield "", ""

    def visitSingleElement(
        self, ctx: OmegaConfGrammarParser.SingleElementContext
    ) -> Any:
        """Evaluate a standalone element by visiting its `element` child."""
        # element EOF
        assert ctx.getChildCount() == 2
        return self.visit(ctx.getChild(0))

    def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any:
        """
        Evaluate a text node.

        A text made of a single interpolation returns the resolved value as-is (not
        converted to a string). Any other text is concatenated into a `str`, with
        escape sequences handled by `_unescape()`.
        """
        # (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+

        # Single interpolation? If yes, return its resolved value "as is".
        if ctx.getChildCount() == 1:
            c = ctx.getChild(0)
            if isinstance(c, OmegaConfGrammarParser.InterpolationContext):
                return self.visitInterpolation(c)

        # Otherwise, concatenate string representations together.
        return self._unescape(list(ctx.getChildren()))

    def _createPrimitive(
        self,
        ctx: Union[
            OmegaConfGrammarParser.PrimitiveContext,
            OmegaConfGrammarParser.DictKeyContext,
        ],
    ) -> Any:
        """
        Evaluate an unquoted primitive (also used for dict keys).

        A single child is typed by its token:
        - `ID` / `UNQUOTED_CHAR` / `COLON` give the text;
        - `NULL` gives `None`;
        - `INT` gives an `int`, `FLOAT` gives a `float`;
        - `BOOL` gives `True` iff the lowercased text is `"true"`;
        - `ESC` gives the unescaped text;
        - an interpolation gives its resolved value as-is.

        Several children are concatenated into an unescaped `str`.
        """
        # (ID | NULL | INT | FLOAT | BOOL | UNQUOTED_CHAR | COLON | ESC | WS | interpolation)+
        if ctx.getChildCount() == 1:
            child = ctx.getChild(0)
            if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
                return self.visitInterpolation(child)
            assert isinstance(child, TerminalNode)
            symbol = child.symbol
            # Parse primitive types.
            if symbol.type in (
                OmegaConfGrammarLexer.ID,
                OmegaConfGrammarLexer.UNQUOTED_CHAR,
                OmegaConfGrammarLexer.COLON,
            ):
                return symbol.text
            elif symbol.type == OmegaConfGrammarLexer.NULL:
                return None
            elif symbol.type == OmegaConfGrammarLexer.INT:
                return int(symbol.text)
            elif symbol.type == OmegaConfGrammarLexer.FLOAT:
                return float(symbol.text)
            elif symbol.type == OmegaConfGrammarLexer.BOOL:
                return symbol.text.lower() == "true"
            elif symbol.type == OmegaConfGrammarLexer.ESC:
                return self._unescape([child])
            elif symbol.type == OmegaConfGrammarLexer.WS:  # pragma: no cover
                # A single WS should have been "consumed" by another token.
                raise AssertionError("WS should never be reached")
            assert False, symbol.type
        # Concatenation of multiple items ==> un-escape the concatenation.
        return self._unescape(list(ctx.getChildren()))

    def _unescape(
        self,
        seq: List[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]],
    ) -> str:
        """
        Concatenate all symbols / interpolations in `seq`, unescaping symbols as needed.

        Interpolations are resolved and cast to string *WITHOUT* escaping their result
        (it is assumed that whatever escaping is required was already handled during the
        resolving of the interpolation).
        """
        chrs = []
        # Pair each node with its successor (`None` for the last one), because
        # whether some escapes are un-escaped depends on what follows them.
        for node, next_node in zip_longest(seq, seq[1:]):
            if isinstance(node, TerminalNode):
                s = node.symbol
                if s.type == OmegaConfGrammarLexer.ESC_INTER:
                    # `ESC_INTER` is of the form `\\...\${`: the formula below computes
                    # the number of characters to keep at the end of the string to remove
                    # the correct number of backslashes.
                    text = s.text[-(len(s.text) // 2 + 1) :]
                elif (
                    # Character sequence identified as requiring un-escaping.
                    s.type == OmegaConfGrammarLexer.ESC
                    or (
                        # At top level, we need to un-escape backslashes that precede
                        # an interpolation.
                        s.type == OmegaConfGrammarLexer.TOP_ESC
                        and isinstance(
                            next_node, OmegaConfGrammarParser.InterpolationContext
                        )
                    )
                    or (
                        # In a quoted string, we need to un-escape backslashes that
                        # either end the string, or are followed by an interpolation.
                        s.type == OmegaConfGrammarLexer.QUOTED_ESC
                        and (
                            next_node is None
                            or isinstance(
                                next_node, OmegaConfGrammarParser.InterpolationContext
                            )
                        )
                    )
                ):
                    text = s.text[1::2]  # un-escape the sequence
                else:
                    text = s.text  # keep the original text
            else:
                assert isinstance(node, OmegaConfGrammarParser.InterpolationContext)
                text = str(self.visitInterpolation(node))
            chrs.append(text)

        return "".join(chrs)
|