# ATNDeserializer.py

# Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
# Use of this file is governed by the BSD 3-clause license that
# can be found in the LICENSE.txt file in the project root.
#/
from uuid import UUID
from io import StringIO
from typing import Callable

from antlr4.IntervalSet import IntervalSet
from antlr4.Token import Token
from antlr4.atn.ATN import ATN
from antlr4.atn.ATNType import ATNType
from antlr4.atn.ATNState import *
from antlr4.atn.Transition import *
from antlr4.atn.LexerAction import *
from antlr4.atn.ATNDeserializationOptions import ATNDeserializationOptions
  15. # This is the earliest supported serialized UUID.
  16. BASE_SERIALIZED_UUID = UUID("AADB8D7E-AEEF-4415-AD2B-8204D6CF042E")
  17. # This UUID indicates the serialized ATN contains two sets of
  18. # IntervalSets, where the second set's values are encoded as
  19. # 32-bit integers to support the full Unicode SMP range up to U+10FFFF.
  20. ADDED_UNICODE_SMP = UUID("59627784-3BE5-417A-B9EB-8131A7286089")
  21. # This list contains all of the currently supported UUIDs, ordered by when
  22. # the feature first appeared in this branch.
  23. SUPPORTED_UUIDS = [ BASE_SERIALIZED_UUID, ADDED_UNICODE_SMP ]
  24. SERIALIZED_VERSION = 3
  25. # This is the current serialized UUID.
  26. SERIALIZED_UUID = ADDED_UNICODE_SMP
  27. class ATNDeserializer (object):
  28. __slots__ = ('deserializationOptions', 'data', 'pos', 'uuid')
  29. def __init__(self, options : ATNDeserializationOptions = None):
  30. if options is None:
  31. options = ATNDeserializationOptions.defaultOptions
  32. self.deserializationOptions = options
  33. # Determines if a particular serialized representation of an ATN supports
  34. # a particular feature, identified by the {@link UUID} used for serializing
  35. # the ATN at the time the feature was first introduced.
  36. #
  37. # @param feature The {@link UUID} marking the first time the feature was
  38. # supported in the serialized ATN.
  39. # @param actualUuid The {@link UUID} of the actual serialized ATN which is
  40. # currently being deserialized.
  41. # @return {@code true} if the {@code actualUuid} value represents a
  42. # serialized ATN at or after the feature identified by {@code feature} was
  43. # introduced; otherwise, {@code false}.
  44. def isFeatureSupported(self, feature : UUID , actualUuid : UUID ):
  45. idx1 = SUPPORTED_UUIDS.index(feature)
  46. if idx1<0:
  47. return False
  48. idx2 = SUPPORTED_UUIDS.index(actualUuid)
  49. return idx2 >= idx1
  50. def deserialize(self, data : str):
  51. self.reset(data)
  52. self.checkVersion()
  53. self.checkUUID()
  54. atn = self.readATN()
  55. self.readStates(atn)
  56. self.readRules(atn)
  57. self.readModes(atn)
  58. sets = []
  59. # First, read all sets with 16-bit Unicode code points <= U+FFFF.
  60. self.readSets(atn, sets, self.readInt)
  61. # Next, if the ATN was serialized with the Unicode SMP feature,
  62. # deserialize sets with 32-bit arguments <= U+10FFFF.
  63. if self.isFeatureSupported(ADDED_UNICODE_SMP, self.uuid):
  64. self.readSets(atn, sets, self.readInt32)
  65. self.readEdges(atn, sets)
  66. self.readDecisions(atn)
  67. self.readLexerActions(atn)
  68. self.markPrecedenceDecisions(atn)
  69. self.verifyATN(atn)
  70. if self.deserializationOptions.generateRuleBypassTransitions \
  71. and atn.grammarType == ATNType.PARSER:
  72. self.generateRuleBypassTransitions(atn)
  73. # re-verify after modification
  74. self.verifyATN(atn)
  75. return atn
  76. def reset(self, data:str):
  77. def adjust(c):
  78. v = ord(c)
  79. return v-2 if v>1 else v + 65533
  80. temp = [ adjust(c) for c in data ]
  81. # don't adjust the first value since that's the version number
  82. temp[0] = ord(data[0])
  83. self.data = temp
  84. self.pos = 0
  85. def checkVersion(self):
  86. version = self.readInt()
  87. if version != SERIALIZED_VERSION:
  88. raise Exception("Could not deserialize ATN with version " + str(version) + " (expected " + str(SERIALIZED_VERSION) + ").")
  89. def checkUUID(self):
  90. uuid = self.readUUID()
  91. if not uuid in SUPPORTED_UUIDS:
  92. raise Exception("Could not deserialize ATN with UUID: " + str(uuid) + \
  93. " (expected " + str(SERIALIZED_UUID) + " or a legacy UUID).", uuid, SERIALIZED_UUID)
  94. self.uuid = uuid
  95. def readATN(self):
  96. idx = self.readInt()
  97. grammarType = ATNType.fromOrdinal(idx)
  98. maxTokenType = self.readInt()
  99. return ATN(grammarType, maxTokenType)
  100. def readStates(self, atn:ATN):
  101. loopBackStateNumbers = []
  102. endStateNumbers = []
  103. nstates = self.readInt()
  104. for i in range(0, nstates):
  105. stype = self.readInt()
  106. # ignore bad type of states
  107. if stype==ATNState.INVALID_TYPE:
  108. atn.addState(None)
  109. continue
  110. ruleIndex = self.readInt()
  111. if ruleIndex == 0xFFFF:
  112. ruleIndex = -1
  113. s = self.stateFactory(stype, ruleIndex)
  114. if stype == ATNState.LOOP_END: # special case
  115. loopBackStateNumber = self.readInt()
  116. loopBackStateNumbers.append((s, loopBackStateNumber))
  117. elif isinstance(s, BlockStartState):
  118. endStateNumber = self.readInt()
  119. endStateNumbers.append((s, endStateNumber))
  120. atn.addState(s)
  121. # delay the assignment of loop back and end states until we know all the state instances have been initialized
  122. for pair in loopBackStateNumbers:
  123. pair[0].loopBackState = atn.states[pair[1]]
  124. for pair in endStateNumbers:
  125. pair[0].endState = atn.states[pair[1]]
  126. numNonGreedyStates = self.readInt()
  127. for i in range(0, numNonGreedyStates):
  128. stateNumber = self.readInt()
  129. atn.states[stateNumber].nonGreedy = True
  130. numPrecedenceStates = self.readInt()
  131. for i in range(0, numPrecedenceStates):
  132. stateNumber = self.readInt()
  133. atn.states[stateNumber].isPrecedenceRule = True
  134. def readRules(self, atn:ATN):
  135. nrules = self.readInt()
  136. if atn.grammarType == ATNType.LEXER:
  137. atn.ruleToTokenType = [0] * nrules
  138. atn.ruleToStartState = [0] * nrules
  139. for i in range(0, nrules):
  140. s = self.readInt()
  141. startState = atn.states[s]
  142. atn.ruleToStartState[i] = startState
  143. if atn.grammarType == ATNType.LEXER:
  144. tokenType = self.readInt()
  145. if tokenType == 0xFFFF:
  146. tokenType = Token.EOF
  147. atn.ruleToTokenType[i] = tokenType
  148. atn.ruleToStopState = [0] * nrules
  149. for state in atn.states:
  150. if not isinstance(state, RuleStopState):
  151. continue
  152. atn.ruleToStopState[state.ruleIndex] = state
  153. atn.ruleToStartState[state.ruleIndex].stopState = state
  154. def readModes(self, atn:ATN):
  155. nmodes = self.readInt()
  156. for i in range(0, nmodes):
  157. s = self.readInt()
  158. atn.modeToStartState.append(atn.states[s])
  159. def readSets(self, atn:ATN, sets:list, readUnicode:Callable[[], int]):
  160. m = self.readInt()
  161. for i in range(0, m):
  162. iset = IntervalSet()
  163. sets.append(iset)
  164. n = self.readInt()
  165. containsEof = self.readInt()
  166. if containsEof!=0:
  167. iset.addOne(-1)
  168. for j in range(0, n):
  169. i1 = readUnicode()
  170. i2 = readUnicode()
  171. iset.addRange(range(i1, i2 + 1)) # range upper limit is exclusive
  172. def readEdges(self, atn:ATN, sets:list):
  173. nedges = self.readInt()
  174. for i in range(0, nedges):
  175. src = self.readInt()
  176. trg = self.readInt()
  177. ttype = self.readInt()
  178. arg1 = self.readInt()
  179. arg2 = self.readInt()
  180. arg3 = self.readInt()
  181. trans = self.edgeFactory(atn, ttype, src, trg, arg1, arg2, arg3, sets)
  182. srcState = atn.states[src]
  183. srcState.addTransition(trans)
  184. # edges for rule stop states can be derived, so they aren't serialized
  185. for state in atn.states:
  186. for i in range(0, len(state.transitions)):
  187. t = state.transitions[i]
  188. if not isinstance(t, RuleTransition):
  189. continue
  190. outermostPrecedenceReturn = -1
  191. if atn.ruleToStartState[t.target.ruleIndex].isPrecedenceRule:
  192. if t.precedence == 0:
  193. outermostPrecedenceReturn = t.target.ruleIndex
  194. trans = EpsilonTransition(t.followState, outermostPrecedenceReturn)
  195. atn.ruleToStopState[t.target.ruleIndex].addTransition(trans)
  196. for state in atn.states:
  197. if isinstance(state, BlockStartState):
  198. # we need to know the end state to set its start state
  199. if state.endState is None:
  200. raise Exception("IllegalState")
  201. # block end states can only be associated to a single block start state
  202. if state.endState.startState is not None:
  203. raise Exception("IllegalState")
  204. state.endState.startState = state
  205. if isinstance(state, PlusLoopbackState):
  206. for i in range(0, len(state.transitions)):
  207. target = state.transitions[i].target
  208. if isinstance(target, PlusBlockStartState):
  209. target.loopBackState = state
  210. elif isinstance(state, StarLoopbackState):
  211. for i in range(0, len(state.transitions)):
  212. target = state.transitions[i].target
  213. if isinstance(target, StarLoopEntryState):
  214. target.loopBackState = state
  215. def readDecisions(self, atn:ATN):
  216. ndecisions = self.readInt()
  217. for i in range(0, ndecisions):
  218. s = self.readInt()
  219. decState = atn.states[s]
  220. atn.decisionToState.append(decState)
  221. decState.decision = i
  222. def readLexerActions(self, atn:ATN):
  223. if atn.grammarType == ATNType.LEXER:
  224. count = self.readInt()
  225. atn.lexerActions = [ None ] * count
  226. for i in range(0, count):
  227. actionType = self.readInt()
  228. data1 = self.readInt()
  229. if data1 == 0xFFFF:
  230. data1 = -1
  231. data2 = self.readInt()
  232. if data2 == 0xFFFF:
  233. data2 = -1
  234. lexerAction = self.lexerActionFactory(actionType, data1, data2)
  235. atn.lexerActions[i] = lexerAction
  236. def generateRuleBypassTransitions(self, atn:ATN):
  237. count = len(atn.ruleToStartState)
  238. atn.ruleToTokenType = [ 0 ] * count
  239. for i in range(0, count):
  240. atn.ruleToTokenType[i] = atn.maxTokenType + i + 1
  241. for i in range(0, count):
  242. self.generateRuleBypassTransition(atn, i)
  243. def generateRuleBypassTransition(self, atn:ATN, idx:int):
  244. bypassStart = BasicBlockStartState()
  245. bypassStart.ruleIndex = idx
  246. atn.addState(bypassStart)
  247. bypassStop = BlockEndState()
  248. bypassStop.ruleIndex = idx
  249. atn.addState(bypassStop)
  250. bypassStart.endState = bypassStop
  251. atn.defineDecisionState(bypassStart)
  252. bypassStop.startState = bypassStart
  253. excludeTransition = None
  254. if atn.ruleToStartState[idx].isPrecedenceRule:
  255. # wrap from the beginning of the rule to the StarLoopEntryState
  256. endState = None
  257. for state in atn.states:
  258. if self.stateIsEndStateFor(state, idx):
  259. endState = state
  260. excludeTransition = state.loopBackState.transitions[0]
  261. break
  262. if excludeTransition is None:
  263. raise Exception("Couldn't identify final state of the precedence rule prefix section.")
  264. else:
  265. endState = atn.ruleToStopState[idx]
  266. # all non-excluded transitions that currently target end state need to target blockEnd instead
  267. for state in atn.states:
  268. for transition in state.transitions:
  269. if transition == excludeTransition:
  270. continue
  271. if transition.target == endState:
  272. transition.target = bypassStop
  273. # all transitions leaving the rule start state need to leave blockStart instead
  274. ruleToStartState = atn.ruleToStartState[idx]
  275. count = len(ruleToStartState.transitions)
  276. while count > 0:
  277. bypassStart.addTransition(ruleToStartState.transitions[count-1])
  278. del ruleToStartState.transitions[-1]
  279. # link the new states
  280. atn.ruleToStartState[idx].addTransition(EpsilonTransition(bypassStart))
  281. bypassStop.addTransition(EpsilonTransition(endState))
  282. matchState = BasicState()
  283. atn.addState(matchState)
  284. matchState.addTransition(AtomTransition(bypassStop, atn.ruleToTokenType[idx]))
  285. bypassStart.addTransition(EpsilonTransition(matchState))
  286. def stateIsEndStateFor(self, state:ATNState, idx:int):
  287. if state.ruleIndex != idx:
  288. return None
  289. if not isinstance(state, StarLoopEntryState):
  290. return None
  291. maybeLoopEndState = state.transitions[len(state.transitions) - 1].target
  292. if not isinstance(maybeLoopEndState, LoopEndState):
  293. return None
  294. if maybeLoopEndState.epsilonOnlyTransitions and \
  295. isinstance(maybeLoopEndState.transitions[0].target, RuleStopState):
  296. return state
  297. else:
  298. return None
  299. #
  300. # Analyze the {@link StarLoopEntryState} states in the specified ATN to set
  301. # the {@link StarLoopEntryState#isPrecedenceDecision} field to the
  302. # correct value.
  303. #
  304. # @param atn The ATN.
  305. #
  306. def markPrecedenceDecisions(self, atn:ATN):
  307. for state in atn.states:
  308. if not isinstance(state, StarLoopEntryState):
  309. continue
  310. # We analyze the ATN to determine if this ATN decision state is the
  311. # decision for the closure block that determines whether a
  312. # precedence rule should continue or complete.
  313. #
  314. if atn.ruleToStartState[state.ruleIndex].isPrecedenceRule:
  315. maybeLoopEndState = state.transitions[len(state.transitions) - 1].target
  316. if isinstance(maybeLoopEndState, LoopEndState):
  317. if maybeLoopEndState.epsilonOnlyTransitions and \
  318. isinstance(maybeLoopEndState.transitions[0].target, RuleStopState):
  319. state.isPrecedenceDecision = True
  320. def verifyATN(self, atn:ATN):
  321. if not self.deserializationOptions.verifyATN:
  322. return
  323. # verify assumptions
  324. for state in atn.states:
  325. if state is None:
  326. continue
  327. self.checkCondition(state.epsilonOnlyTransitions or len(state.transitions) <= 1)
  328. if isinstance(state, PlusBlockStartState):
  329. self.checkCondition(state.loopBackState is not None)
  330. if isinstance(state, StarLoopEntryState):
  331. self.checkCondition(state.loopBackState is not None)
  332. self.checkCondition(len(state.transitions) == 2)
  333. if isinstance(state.transitions[0].target, StarBlockStartState):
  334. self.checkCondition(isinstance(state.transitions[1].target, LoopEndState))
  335. self.checkCondition(not state.nonGreedy)
  336. elif isinstance(state.transitions[0].target, LoopEndState):
  337. self.checkCondition(isinstance(state.transitions[1].target, StarBlockStartState))
  338. self.checkCondition(state.nonGreedy)
  339. else:
  340. raise Exception("IllegalState")
  341. if isinstance(state, StarLoopbackState):
  342. self.checkCondition(len(state.transitions) == 1)
  343. self.checkCondition(isinstance(state.transitions[0].target, StarLoopEntryState))
  344. if isinstance(state, LoopEndState):
  345. self.checkCondition(state.loopBackState is not None)
  346. if isinstance(state, RuleStartState):
  347. self.checkCondition(state.stopState is not None)
  348. if isinstance(state, BlockStartState):
  349. self.checkCondition(state.endState is not None)
  350. if isinstance(state, BlockEndState):
  351. self.checkCondition(state.startState is not None)
  352. if isinstance(state, DecisionState):
  353. self.checkCondition(len(state.transitions) <= 1 or state.decision >= 0)
  354. else:
  355. self.checkCondition(len(state.transitions) <= 1 or isinstance(state, RuleStopState))
  356. def checkCondition(self, condition:bool, message=None):
  357. if not condition:
  358. if message is None:
  359. message = "IllegalState"
  360. raise Exception(message)
  361. def readInt(self):
  362. i = self.data[self.pos]
  363. self.pos += 1
  364. return i
  365. def readInt32(self):
  366. low = self.readInt()
  367. high = self.readInt()
  368. return low | (high << 16)
  369. def readLong(self):
  370. low = self.readInt32()
  371. high = self.readInt32()
  372. return (low & 0x00000000FFFFFFFF) | (high << 32)
  373. def readUUID(self):
  374. low = self.readLong()
  375. high = self.readLong()
  376. allBits = (low & 0xFFFFFFFFFFFFFFFF) | (high << 64)
  377. return UUID(int=allBits)
  378. edgeFactories = [ lambda args : None,
  379. lambda atn, src, trg, arg1, arg2, arg3, sets, target : EpsilonTransition(target),
  380. lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
  381. RangeTransition(target, Token.EOF, arg2) if arg3 != 0 else RangeTransition(target, arg1, arg2),
  382. lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
  383. RuleTransition(atn.states[arg1], arg2, arg3, target),
  384. lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
  385. PredicateTransition(target, arg1, arg2, arg3 != 0),
  386. lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
  387. AtomTransition(target, Token.EOF) if arg3 != 0 else AtomTransition(target, arg1),
  388. lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
  389. ActionTransition(target, arg1, arg2, arg3 != 0),
  390. lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
  391. SetTransition(target, sets[arg1]),
  392. lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
  393. NotSetTransition(target, sets[arg1]),
  394. lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
  395. WildcardTransition(target),
  396. lambda atn, src, trg, arg1, arg2, arg3, sets, target : \
  397. PrecedencePredicateTransition(target, arg1)
  398. ]
  399. def edgeFactory(self, atn:ATN, type:int, src:int, trg:int, arg1:int, arg2:int, arg3:int, sets:list):
  400. target = atn.states[trg]
  401. if type > len(self.edgeFactories) or self.edgeFactories[type] is None:
  402. raise Exception("The specified transition type: " + str(type) + " is not valid.")
  403. else:
  404. return self.edgeFactories[type](atn, src, trg, arg1, arg2, arg3, sets, target)
  405. stateFactories = [ lambda : None,
  406. lambda : BasicState(),
  407. lambda : RuleStartState(),
  408. lambda : BasicBlockStartState(),
  409. lambda : PlusBlockStartState(),
  410. lambda : StarBlockStartState(),
  411. lambda : TokensStartState(),
  412. lambda : RuleStopState(),
  413. lambda : BlockEndState(),
  414. lambda : StarLoopbackState(),
  415. lambda : StarLoopEntryState(),
  416. lambda : PlusLoopbackState(),
  417. lambda : LoopEndState()
  418. ]
  419. def stateFactory(self, type:int, ruleIndex:int):
  420. if type> len(self.stateFactories) or self.stateFactories[type] is None:
  421. raise Exception("The specified state type " + str(type) + " is not valid.")
  422. else:
  423. s = self.stateFactories[type]()
  424. if s is not None:
  425. s.ruleIndex = ruleIndex
  426. return s
  427. CHANNEL = 0 #The type of a {@link LexerChannelAction} action.
  428. CUSTOM = 1 #The type of a {@link LexerCustomAction} action.
  429. MODE = 2 #The type of a {@link LexerModeAction} action.
  430. MORE = 3 #The type of a {@link LexerMoreAction} action.
  431. POP_MODE = 4 #The type of a {@link LexerPopModeAction} action.
  432. PUSH_MODE = 5 #The type of a {@link LexerPushModeAction} action.
  433. SKIP = 6 #The type of a {@link LexerSkipAction} action.
  434. TYPE = 7 #The type of a {@link LexerTypeAction} action.
  435. actionFactories = [ lambda data1, data2: LexerChannelAction(data1),
  436. lambda data1, data2: LexerCustomAction(data1, data2),
  437. lambda data1, data2: LexerModeAction(data1),
  438. lambda data1, data2: LexerMoreAction.INSTANCE,
  439. lambda data1, data2: LexerPopModeAction.INSTANCE,
  440. lambda data1, data2: LexerPushModeAction(data1),
  441. lambda data1, data2: LexerSkipAction.INSTANCE,
  442. lambda data1, data2: LexerTypeAction(data1)
  443. ]
  444. def lexerActionFactory(self, type:int, data1:int, data2:int):
  445. if type > len(self.actionFactories) or self.actionFactories[type] is None:
  446. raise Exception("The specified lexer action type " + str(type) + " is not valid.")
  447. else:
  448. return self.actionFactories[type](data1, data2)