| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180 |
- #
- # Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
- # Use of this file is governed by the BSD 3-clause license that
- # can be found in the LICENSE.txt file in the project root.
- #
- from io import StringIO
- from antlr4.Token import Token
# Placeholder binding: the annotation ``other: IntervalSet`` on ``addSet`` is
# evaluated while the class body is still executing, so the name must already
# exist at module level.  The class statement below rebinds it.
IntervalSet = None


class IntervalSet(object):
    """A set of integers stored as a list of ``range`` objects.

    ``intervals`` is ``None`` until the first range is added.  Afterwards it
    is a list of ranges that ``addRange`` keeps sorted by ``start`` and merges
    whenever two ranges overlap or touch.  All ranges are treated as having a
    step of 1; only their ``start`` and ``stop`` are inspected.
    """

    __slots__ = ('intervals', 'readonly')

    def __init__(self):
        # None means "nothing has ever been added"; see toString().
        self.intervals = None
        # Assigned here; no method of this class reads it.
        self.readonly = False

    def __iter__(self):
        """Yield every integer in the set, interval by interval."""
        if self.intervals is not None:
            for interval in self.intervals:
                yield from interval

    def __getitem__(self, item):
        """Return the ``item``-th value in iteration order.

        Returns ``Token.INVALID_TYPE`` instead of raising when no value sits
        at that position (including negative positions).
        """
        for position, value in enumerate(self):
            if position == item:
                return value
        return Token.INVALID_TYPE

    def addOne(self, v:int):
        """Add the single value ``v`` to the set."""
        self.addRange(range(v, v+1))

    def addRange(self, v:range):
        """Add every value of ``v`` to the set, merging with stored ranges.

        Empty ranges (``v.stop <= v.start``) are ignored so they can never be
        stored as meaningless entries in ``intervals``.
        """
        if v.stop <= v.start:
            return
        if self.intervals is None:
            self.intervals = [v]
            return
        # Find the insert position; every branch that mutates the list
        # returns immediately, so the list is never iterated after a change.
        for k, i in enumerate(self.intervals):
            # distinct range strictly below i -> insert before it
            if v.stop < i.start:
                self.intervals.insert(k, v)
                return
            # contiguous with the bottom of i -> extend i downwards
            if v.stop == i.start:
                self.intervals[k] = range(v.start, i.stop)
                return
            # overlaps i or touches its top -> merge, then absorb followers
            if v.start <= i.stop:
                self.intervals[k] = range(min(i.start, v.start),
                                          max(i.stop, v.stop))
                self.reduce(k)
                return
        # greater than any existing range
        self.intervals.append(v)

    def addSet(self, other:IntervalSet):
        """Add every range of ``other`` to this set and return ``self``."""
        if other.intervals is not None:
            for i in other.intervals:
                self.addRange(i)
        return self

    def reduce(self, k:int):
        """Merge the range at index ``k`` with the ranges that follow it.

        Called after the range at ``k`` has grown.  Followers that are fully
        covered are dropped; the first follower that only partly overlaps (or
        touches) is merged and the scan stops.
        """
        # only need to reduce while k is not the last index
        while k < len(self.intervals) - 1:
            left = self.intervals[k]
            right = self.intervals[k+1]
            if left.stop >= right.stop:
                # right is contained in left: drop it and look at the next one
                self.intervals.pop(k+1)
            elif left.stop >= right.start:
                # partial overlap or adjacency: fuse the two and stop
                self.intervals[k] = range(left.start, right.stop)
                self.intervals.pop(k+1)
                return
            else:
                return

    def complement(self, start, stop):
        """Return a new set holding the values in ``start..stop`` (inclusive)
        that are not in this set."""
        result = IntervalSet()
        result.addRange(range(start, stop+1))
        # A set that never received a range has intervals == None.
        if self.intervals is not None:
            for i in self.intervals:
                result.removeRange(i)
        return result

    def __contains__(self, item):
        """Return True when ``item`` lies in one of the stored ranges."""
        if self.intervals is None:
            return False
        return any(item in i for i in self.intervals)

    def __len__(self):
        """Return the number of integers in the set (0 for a fresh set)."""
        if self.intervals is None:
            return 0
        return sum(len(i) for i in self.intervals)

    def removeRange(self, v):
        """Remove every value of ``v`` from the set.

        The surviving pieces are collected into a fresh list instead of
        editing ``intervals`` while iterating it, which keeps the result
        sorted and guarantees that no stored range is skipped.
        """
        if self.intervals is None or v.stop <= v.start:
            return
        kept = []
        for i in self.intervals:
            # no overlap at all: keep the stored range untouched
            if i.stop <= v.start or v.stop <= i.start:
                kept.append(i)
                continue
            # part of i below v survives
            if i.start < v.start:
                kept.append(range(i.start, v.start))
            # part of i above v survives
            if v.stop < i.stop:
                kept.append(range(v.stop, i.stop))
        # Slice-assign so the existing list object is updated in place.
        self.intervals[:] = kept

    def removeOne(self, v):
        """Remove the single value ``v`` from the set."""
        self.removeRange(range(v, v+1))

    def toString(self, literalNames:list, symbolicNames:list):
        """Render the set using token names.

        Returns ``"{}"`` for a set that never received a range.  Otherwise
        the element names are joined with ``", "`` and wrapped in braces only
        when there is more than one element.
        """
        if self.intervals is None:
            return "{}"
        names = [self.elementName(literalNames, symbolicNames, j)
                 for j in self]
        text = ", ".join(names)
        if len(names) > 1:
            return "{" + text + "}"
        return text

    def elementName(self, literalNames:list, symbolicNames:list, a:int):
        """Return the display name of token type ``a``.

        ``Token.EOF`` and ``Token.EPSILON`` get fixed names.  Otherwise the
        literal name is preferred unless it is the ``"<INVALID>"`` marker,
        then the symbolic name, then ``"<UNKNOWN>"``.
        """
        if a==Token.EOF:
            return "<EOF>"
        elif a==Token.EPSILON:
            return "<EPSILON>"
        else:
            if a<len(literalNames) and literalNames[a] != "<INVALID>":
                return literalNames[a]
            if a<len(symbolicNames):
                return symbolicNames[a]
            return "<UNKNOWN>"
|