grammar_visitor.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. import sys
  2. import warnings
  3. from itertools import zip_longest
  4. from typing import (
  5. TYPE_CHECKING,
  6. Any,
  7. Callable,
  8. Dict,
  9. Generator,
  10. List,
  11. Optional,
  12. Set,
  13. Tuple,
  14. Union,
  15. )
  16. from antlr4 import TerminalNode
  17. from .errors import InterpolationResolutionError
  18. if TYPE_CHECKING:
  19. from .base import Node # noqa F401
  20. try:
  21. from omegaconf.grammar.gen.OmegaConfGrammarLexer import OmegaConfGrammarLexer
  22. from omegaconf.grammar.gen.OmegaConfGrammarParser import OmegaConfGrammarParser
  23. from omegaconf.grammar.gen.OmegaConfGrammarParserVisitor import (
  24. OmegaConfGrammarParserVisitor,
  25. )
  26. except ModuleNotFoundError: # pragma: no cover
  27. print(
  28. "Error importing OmegaConf's generated parsers, run `python setup.py antlr` to regenerate.",
  29. file=sys.stderr,
  30. )
  31. sys.exit(1)
  32. class GrammarVisitor(OmegaConfGrammarParserVisitor):
  33. def __init__(
  34. self,
  35. node_interpolation_callback: Callable[
  36. [str, Optional[Set[int]]],
  37. Optional["Node"],
  38. ],
  39. resolver_interpolation_callback: Callable[..., Any],
  40. memo: Optional[Set[int]],
  41. **kw: Dict[Any, Any],
  42. ):
  43. """
  44. Constructor.
  45. :param node_interpolation_callback: Callback function that is called when
  46. needing to resolve a node interpolation. This function should take a single
  47. string input which is the key's dot path (ex: `"foo.bar"`).
  48. :param resolver_interpolation_callback: Callback function that is called when
  49. needing to resolve a resolver interpolation. This function should accept
  50. three keyword arguments: `name` (str, the name of the resolver),
  51. `args` (tuple, the inputs to the resolver), and `args_str` (tuple,
  52. the string representation of the inputs to the resolver).
  53. :param kw: Additional keyword arguments to be forwarded to parent class.
  54. """
  55. super().__init__(**kw)
  56. self.node_interpolation_callback = node_interpolation_callback
  57. self.resolver_interpolation_callback = resolver_interpolation_callback
  58. self.memo = memo
  59. def aggregateResult(self, aggregate: List[Any], nextResult: Any) -> List[Any]:
  60. raise NotImplementedError
  61. def defaultResult(self) -> List[Any]:
  62. # Raising an exception because not currently used (like `aggregateResult()`).
  63. raise NotImplementedError
  64. def visitConfigKey(self, ctx: OmegaConfGrammarParser.ConfigKeyContext) -> str:
  65. from ._utils import _get_value
  66. # interpolation | ID | INTER_KEY
  67. assert ctx.getChildCount() == 1
  68. child = ctx.getChild(0)
  69. if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
  70. res = _get_value(self.visitInterpolation(child))
  71. if not isinstance(res, str):
  72. raise InterpolationResolutionError(
  73. f"The following interpolation is used to denote a config key and "
  74. f"thus should return a string, but instead returned `{res}` of "
  75. f"type `{type(res)}`: {ctx.getChild(0).getText()}"
  76. )
  77. return res
  78. else:
  79. assert isinstance(child, TerminalNode) and isinstance(
  80. child.symbol.text, str
  81. )
  82. return child.symbol.text
  83. def visitConfigValue(self, ctx: OmegaConfGrammarParser.ConfigValueContext) -> Any:
  84. # text EOF
  85. assert ctx.getChildCount() == 2
  86. return self.visit(ctx.getChild(0))
  87. def visitDictKey(self, ctx: OmegaConfGrammarParser.DictKeyContext) -> Any:
  88. return self._createPrimitive(ctx)
  89. def visitDictContainer(
  90. self, ctx: OmegaConfGrammarParser.DictContainerContext
  91. ) -> Dict[Any, Any]:
  92. # BRACE_OPEN (dictKeyValuePair (COMMA dictKeyValuePair)*)? BRACE_CLOSE
  93. assert ctx.getChildCount() >= 2
  94. return dict(
  95. self.visitDictKeyValuePair(ctx.getChild(i))
  96. for i in range(1, ctx.getChildCount() - 1, 2)
  97. )
  98. def visitElement(self, ctx: OmegaConfGrammarParser.ElementContext) -> Any:
  99. # primitive | quotedValue | listContainer | dictContainer
  100. assert ctx.getChildCount() == 1
  101. return self.visit(ctx.getChild(0))
  102. def visitInterpolation(
  103. self, ctx: OmegaConfGrammarParser.InterpolationContext
  104. ) -> Any:
  105. assert ctx.getChildCount() == 1 # interpolationNode | interpolationResolver
  106. return self.visit(ctx.getChild(0))
  107. def visitInterpolationNode(
  108. self, ctx: OmegaConfGrammarParser.InterpolationNodeContext
  109. ) -> Optional["Node"]:
  110. # INTER_OPEN
  111. # DOT* // relative interpolation?
  112. # (configKey | BRACKET_OPEN configKey BRACKET_CLOSE) // foo, [foo]
  113. # (DOT configKey | BRACKET_OPEN configKey BRACKET_CLOSE)* // .foo, [foo], .foo[bar], [foo].bar[baz]
  114. # INTER_CLOSE;
  115. assert ctx.getChildCount() >= 3
  116. inter_key_tokens = [] # parsed elements of the dot path
  117. for child in ctx.getChildren():
  118. if isinstance(child, TerminalNode):
  119. s = child.symbol
  120. if s.type in [
  121. OmegaConfGrammarLexer.DOT,
  122. OmegaConfGrammarLexer.BRACKET_OPEN,
  123. OmegaConfGrammarLexer.BRACKET_CLOSE,
  124. ]:
  125. inter_key_tokens.append(s.text)
  126. else:
  127. assert s.type in (
  128. OmegaConfGrammarLexer.INTER_OPEN,
  129. OmegaConfGrammarLexer.INTER_CLOSE,
  130. )
  131. else:
  132. assert isinstance(child, OmegaConfGrammarParser.ConfigKeyContext)
  133. inter_key_tokens.append(self.visitConfigKey(child))
  134. inter_key = "".join(inter_key_tokens)
  135. return self.node_interpolation_callback(inter_key, self.memo)
  136. def visitInterpolationResolver(
  137. self, ctx: OmegaConfGrammarParser.InterpolationResolverContext
  138. ) -> Any:
  139. # INTER_OPEN resolverName COLON sequence? BRACE_CLOSE
  140. assert 4 <= ctx.getChildCount() <= 5
  141. resolver_name = self.visit(ctx.getChild(1))
  142. maybe_seq = ctx.getChild(3)
  143. args = []
  144. args_str = []
  145. if isinstance(maybe_seq, TerminalNode): # means there are no args
  146. assert maybe_seq.symbol.type == OmegaConfGrammarLexer.BRACE_CLOSE
  147. else:
  148. assert isinstance(maybe_seq, OmegaConfGrammarParser.SequenceContext)
  149. for val, txt in self.visitSequence(maybe_seq):
  150. args.append(val)
  151. args_str.append(txt)
  152. return self.resolver_interpolation_callback(
  153. name=resolver_name,
  154. args=tuple(args),
  155. args_str=tuple(args_str),
  156. )
  157. def visitDictKeyValuePair(
  158. self, ctx: OmegaConfGrammarParser.DictKeyValuePairContext
  159. ) -> Tuple[Any, Any]:
  160. from ._utils import _get_value
  161. assert ctx.getChildCount() == 3 # dictKey COLON element
  162. key = self.visit(ctx.getChild(0))
  163. colon = ctx.getChild(1)
  164. assert (
  165. isinstance(colon, TerminalNode)
  166. and colon.symbol.type == OmegaConfGrammarLexer.COLON
  167. )
  168. value = _get_value(self.visitElement(ctx.getChild(2)))
  169. return key, value
  170. def visitListContainer(
  171. self, ctx: OmegaConfGrammarParser.ListContainerContext
  172. ) -> List[Any]:
  173. # BRACKET_OPEN sequence? BRACKET_CLOSE;
  174. assert ctx.getChildCount() in (2, 3)
  175. if ctx.getChildCount() == 2:
  176. return []
  177. sequence = ctx.getChild(1)
  178. assert isinstance(sequence, OmegaConfGrammarParser.SequenceContext)
  179. return list(val for val, _ in self.visitSequence(sequence)) # ignore raw text
  180. def visitPrimitive(self, ctx: OmegaConfGrammarParser.PrimitiveContext) -> Any:
  181. return self._createPrimitive(ctx)
  182. def visitQuotedValue(self, ctx: OmegaConfGrammarParser.QuotedValueContext) -> str:
  183. # (QUOTE_OPEN_SINGLE | QUOTE_OPEN_DOUBLE) text? MATCHING_QUOTE_CLOSE
  184. n = ctx.getChildCount()
  185. assert n in [2, 3]
  186. return str(self.visit(ctx.getChild(1))) if n == 3 else ""
  187. def visitResolverName(self, ctx: OmegaConfGrammarParser.ResolverNameContext) -> str:
  188. from ._utils import _get_value
  189. # (interpolation | ID) (DOT (interpolation | ID))*
  190. assert ctx.getChildCount() >= 1
  191. items = []
  192. for child in list(ctx.getChildren())[::2]:
  193. if isinstance(child, TerminalNode):
  194. assert child.symbol.type == OmegaConfGrammarLexer.ID
  195. items.append(child.symbol.text)
  196. else:
  197. assert isinstance(child, OmegaConfGrammarParser.InterpolationContext)
  198. item = _get_value(self.visitInterpolation(child))
  199. if not isinstance(item, str):
  200. raise InterpolationResolutionError(
  201. f"The name of a resolver must be a string, but the interpolation "
  202. f"{child.getText()} resolved to `{item}` which is of type "
  203. f"{type(item)}"
  204. )
  205. items.append(item)
  206. return ".".join(items)
  207. def visitSequence(
  208. self, ctx: OmegaConfGrammarParser.SequenceContext
  209. ) -> Generator[Any, None, None]:
  210. from ._utils import _get_value
  211. # (element (COMMA element?)*) | (COMMA element?)+
  212. assert ctx.getChildCount() >= 1
  213. # DEPRECATED: remove in 2.2 (revert #571)
  214. def empty_str_warning() -> None:
  215. txt = ctx.getText()
  216. warnings.warn(
  217. f"In the sequence `{txt}` some elements are missing: please replace "
  218. f"them with empty quoted strings. "
  219. f"See https://github.com/omry/omegaconf/issues/572 for details.",
  220. category=UserWarning,
  221. )
  222. is_previous_comma = True # whether previous child was a comma (init to True)
  223. for child in ctx.getChildren():
  224. if isinstance(child, OmegaConfGrammarParser.ElementContext):
  225. # Also preserve the original text representation of `child` so
  226. # as to allow backward compatibility with old resolvers (registered
  227. # with `legacy_register_resolver()`). Note that we cannot just cast
  228. # the value to string later as for instance `null` would become "None".
  229. yield _get_value(self.visitElement(child)), child.getText()
  230. is_previous_comma = False
  231. else:
  232. assert (
  233. isinstance(child, TerminalNode)
  234. and child.symbol.type == OmegaConfGrammarLexer.COMMA
  235. )
  236. if is_previous_comma:
  237. empty_str_warning()
  238. yield "", ""
  239. else:
  240. is_previous_comma = True
  241. if is_previous_comma:
  242. # Trailing comma.
  243. empty_str_warning()
  244. yield "", ""
  245. def visitSingleElement(
  246. self, ctx: OmegaConfGrammarParser.SingleElementContext
  247. ) -> Any:
  248. # element EOF
  249. assert ctx.getChildCount() == 2
  250. return self.visit(ctx.getChild(0))
  251. def visitText(self, ctx: OmegaConfGrammarParser.TextContext) -> Any:
  252. # (interpolation | ANY_STR | ESC | ESC_INTER | TOP_ESC | QUOTED_ESC)+
  253. # Single interpolation? If yes, return its resolved value "as is".
  254. if ctx.getChildCount() == 1:
  255. c = ctx.getChild(0)
  256. if isinstance(c, OmegaConfGrammarParser.InterpolationContext):
  257. return self.visitInterpolation(c)
  258. # Otherwise, concatenate string representations together.
  259. return self._unescape(list(ctx.getChildren()))
  260. def _createPrimitive(
  261. self,
  262. ctx: Union[
  263. OmegaConfGrammarParser.PrimitiveContext,
  264. OmegaConfGrammarParser.DictKeyContext,
  265. ],
  266. ) -> Any:
  267. # (ID | NULL | INT | FLOAT | BOOL | UNQUOTED_CHAR | COLON | ESC | WS | interpolation)+
  268. if ctx.getChildCount() == 1:
  269. child = ctx.getChild(0)
  270. if isinstance(child, OmegaConfGrammarParser.InterpolationContext):
  271. return self.visitInterpolation(child)
  272. assert isinstance(child, TerminalNode)
  273. symbol = child.symbol
  274. # Parse primitive types.
  275. if symbol.type in (
  276. OmegaConfGrammarLexer.ID,
  277. OmegaConfGrammarLexer.UNQUOTED_CHAR,
  278. OmegaConfGrammarLexer.COLON,
  279. ):
  280. return symbol.text
  281. elif symbol.type == OmegaConfGrammarLexer.NULL:
  282. return None
  283. elif symbol.type == OmegaConfGrammarLexer.INT:
  284. return int(symbol.text)
  285. elif symbol.type == OmegaConfGrammarLexer.FLOAT:
  286. return float(symbol.text)
  287. elif symbol.type == OmegaConfGrammarLexer.BOOL:
  288. return symbol.text.lower() == "true"
  289. elif symbol.type == OmegaConfGrammarLexer.ESC:
  290. return self._unescape([child])
  291. elif symbol.type == OmegaConfGrammarLexer.WS: # pragma: no cover
  292. # A single WS should have been "consumed" by another token.
  293. raise AssertionError("WS should never be reached")
  294. assert False, symbol.type
  295. # Concatenation of multiple items ==> un-escape the concatenation.
  296. return self._unescape(list(ctx.getChildren()))
  297. def _unescape(
  298. self,
  299. seq: List[Union[TerminalNode, OmegaConfGrammarParser.InterpolationContext]],
  300. ) -> str:
  301. """
  302. Concatenate all symbols / interpolations in `seq`, unescaping symbols as needed.
  303. Interpolations are resolved and cast to string *WITHOUT* escaping their result
  304. (it is assumed that whatever escaping is required was already handled during the
  305. resolving of the interpolation).
  306. """
  307. chrs = []
  308. for node, next_node in zip_longest(seq, seq[1:]):
  309. if isinstance(node, TerminalNode):
  310. s = node.symbol
  311. if s.type == OmegaConfGrammarLexer.ESC_INTER:
  312. # `ESC_INTER` is of the form `\\...\${`: the formula below computes
  313. # the number of characters to keep at the end of the string to remove
  314. # the correct number of backslashes.
  315. text = s.text[-(len(s.text) // 2 + 1) :]
  316. elif (
  317. # Character sequence identified as requiring un-escaping.
  318. s.type == OmegaConfGrammarLexer.ESC
  319. or (
  320. # At top level, we need to un-escape backslashes that precede
  321. # an interpolation.
  322. s.type == OmegaConfGrammarLexer.TOP_ESC
  323. and isinstance(
  324. next_node, OmegaConfGrammarParser.InterpolationContext
  325. )
  326. )
  327. or (
  328. # In a quoted sring, we need to un-escape backslashes that
  329. # either end the string, or are followed by an interpolation.
  330. s.type == OmegaConfGrammarLexer.QUOTED_ESC
  331. and (
  332. next_node is None
  333. or isinstance(
  334. next_node, OmegaConfGrammarParser.InterpolationContext
  335. )
  336. )
  337. )
  338. ):
  339. text = s.text[1::2] # un-escape the sequence
  340. else:
  341. text = s.text # keep the original text
  342. else:
  343. assert isinstance(node, OmegaConfGrammarParser.InterpolationContext)
  344. text = str(self.visitInterpolation(node))
  345. chrs.append(text)
  346. return "".join(chrs)