IntervalSet.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. #
  2. # Copyright (c) 2012-2017 The ANTLR Project. All rights reserved.
  3. # Use of this file is governed by the BSD 3-clause license that
  4. # can be found in the LICENSE.txt file in the project root.
  5. #
  6. from io import StringIO
  7. from antlr4.Token import Token
# need forward declarations: pre-bind the module-level name so that an
# annotation written as ``IntervalSet`` inside the class body below (e.g. the
# ``other`` parameter of ``addSet``) can be evaluated while the class
# statement is still executing; the ``class`` statement then rebinds the name
# to the real class.
IntervalSet = None
  10. class IntervalSet(object):
  11. __slots__ = ('intervals', 'readonly')
  12. def __init__(self):
  13. self.intervals = None
  14. self.readonly = False
  15. def __iter__(self):
  16. if self.intervals is not None:
  17. for i in self.intervals:
  18. for c in i:
  19. yield c
  20. def __getitem__(self, item):
  21. i = 0
  22. for k in self:
  23. if i==item:
  24. return k
  25. else:
  26. i += 1
  27. return Token.INVALID_TYPE
  28. def addOne(self, v:int):
  29. self.addRange(range(v, v+1))
  30. def addRange(self, v:range):
  31. if self.intervals is None:
  32. self.intervals = list()
  33. self.intervals.append(v)
  34. else:
  35. # find insert pos
  36. k = 0
  37. for i in self.intervals:
  38. # distinct range -> insert
  39. if v.stop<i.start:
  40. self.intervals.insert(k, v)
  41. return
  42. # contiguous range -> adjust
  43. elif v.stop==i.start:
  44. self.intervals[k] = range(v.start, i.stop)
  45. return
  46. # overlapping range -> adjust and reduce
  47. elif v.start<=i.stop:
  48. self.intervals[k] = range(min(i.start,v.start), max(i.stop,v.stop))
  49. self.reduce(k)
  50. return
  51. k += 1
  52. # greater than any existing
  53. self.intervals.append(v)
  54. def addSet(self, other:IntervalSet):
  55. if other.intervals is not None:
  56. for i in other.intervals:
  57. self.addRange(i)
  58. return self
  59. def reduce(self, k:int):
  60. # only need to reduce if k is not the last
  61. if k<len(self.intervals)-1:
  62. l = self.intervals[k]
  63. r = self.intervals[k+1]
  64. # if r contained in l
  65. if l.stop >= r.stop:
  66. self.intervals.pop(k+1)
  67. self.reduce(k)
  68. elif l.stop >= r.start:
  69. self.intervals[k] = range(l.start, r.stop)
  70. self.intervals.pop(k+1)
  71. def complement(self, start, stop):
  72. result = IntervalSet()
  73. result.addRange(range(start,stop+1))
  74. for i in self.intervals:
  75. result.removeRange(i)
  76. return result
  77. def __contains__(self, item):
  78. if self.intervals is None:
  79. return False
  80. else:
  81. return any(item in i for i in self.intervals)
  82. def __len__(self):
  83. return sum(len(i) for i in self.intervals)
  84. def removeRange(self, v):
  85. if v.start==v.stop-1:
  86. self.removeOne(v.start)
  87. elif self.intervals is not None:
  88. k = 0
  89. for i in self.intervals:
  90. # intervals are ordered
  91. if v.stop<=i.start:
  92. return
  93. # check for including range, split it
  94. elif v.start>i.start and v.stop<i.stop:
  95. self.intervals[k] = range(i.start, v.start)
  96. x = range(v.stop, i.stop)
  97. self.intervals.insert(k, x)
  98. return
  99. # check for included range, remove it
  100. elif v.start<=i.start and v.stop>=i.stop:
  101. self.intervals.pop(k)
  102. k -= 1 # need another pass
  103. # check for lower boundary
  104. elif v.start<i.stop:
  105. self.intervals[k] = range(i.start, v.start)
  106. # check for upper boundary
  107. elif v.stop<i.stop:
  108. self.intervals[k] = range(v.stop, i.stop)
  109. k += 1
  110. def removeOne(self, v):
  111. if self.intervals is not None:
  112. k = 0
  113. for i in self.intervals:
  114. # intervals is ordered
  115. if v<i.start:
  116. return
  117. # check for single value range
  118. elif v==i.start and v==i.stop-1:
  119. self.intervals.pop(k)
  120. return
  121. # check for lower boundary
  122. elif v==i.start:
  123. self.intervals[k] = range(i.start+1, i.stop)
  124. return
  125. # check for upper boundary
  126. elif v==i.stop-1:
  127. self.intervals[k] = range(i.start, i.stop-1)
  128. return
  129. # split existing range
  130. elif v<i.stop-1:
  131. x = range(i.start, v)
  132. self.intervals[k] = range(v + 1, i.stop)
  133. self.intervals.insert(k, x)
  134. return
  135. k += 1
  136. def toString(self, literalNames:list, symbolicNames:list):
  137. if self.intervals is None:
  138. return "{}"
  139. with StringIO() as buf:
  140. if len(self)>1:
  141. buf.write("{")
  142. first = True
  143. for i in self.intervals:
  144. for j in i:
  145. if not first:
  146. buf.write(", ")
  147. buf.write(self.elementName(literalNames, symbolicNames, j))
  148. first = False
  149. if len(self)>1:
  150. buf.write("}")
  151. return buf.getvalue()
  152. def elementName(self, literalNames:list, symbolicNames:list, a:int):
  153. if a==Token.EOF:
  154. return "<EOF>"
  155. elif a==Token.EPSILON:
  156. return "<EPSILON>"
  157. else:
  158. if a<len(literalNames) and literalNames[a] != "<INVALID>":
  159. return literalNames[a]
  160. if a<len(symbolicNames):
  161. return symbolicNames[a]
  162. return "<UNKNOWN>"