utils.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416
  1. import unicodedata
  2. import os
  3. from itertools import product
  4. from collections import deque
  5. from typing import Callable, Iterator, List, Optional, Tuple, Type, TypeVar, Union, Dict, Any, Sequence, Iterable, AbstractSet
  6. ###{standalone
  7. import sys, re
  8. import logging
  9. from dataclasses import dataclass
  10. from typing import Generic, AnyStr
  11. logger: logging.Logger = logging.getLogger("lark")
  12. logger.addHandler(logging.StreamHandler())
  13. # Set to highest level, since we have some warnings amongst the code
  14. # By default, we should not output any log messages
  15. logger.setLevel(logging.CRITICAL)
  16. NO_VALUE = object()
  17. T = TypeVar("T")
  18. def classify(seq: Iterable, key: Optional[Callable] = None, value: Optional[Callable] = None) -> Dict:
  19. d: Dict[Any, Any] = {}
  20. for item in seq:
  21. k = key(item) if (key is not None) else item
  22. v = value(item) if (value is not None) else item
  23. try:
  24. d[k].append(v)
  25. except KeyError:
  26. d[k] = [v]
  27. return d
  28. def _deserialize(data: Any, namespace: Dict[str, Any], memo: Dict) -> Any:
  29. if isinstance(data, dict):
  30. if '__type__' in data: # Object
  31. class_ = namespace[data['__type__']]
  32. return class_.deserialize(data, memo)
  33. elif '@' in data:
  34. return memo[data['@']]
  35. return {key:_deserialize(value, namespace, memo) for key, value in data.items()}
  36. elif isinstance(data, list):
  37. return [_deserialize(value, namespace, memo) for value in data]
  38. return data
  39. _T = TypeVar("_T", bound="Serialize")
  40. class Serialize:
  41. """Safe-ish serialization interface that doesn't rely on Pickle
  42. Attributes:
  43. __serialize_fields__ (List[str]): Fields (aka attributes) to serialize.
  44. __serialize_namespace__ (list): List of classes that deserialization is allowed to instantiate.
  45. Should include all field types that aren't builtin types.
  46. """
  47. def memo_serialize(self, types_to_memoize: List) -> Any:
  48. memo = SerializeMemoizer(types_to_memoize)
  49. return self.serialize(memo), memo.serialize()
  50. def serialize(self, memo = None) -> Dict[str, Any]:
  51. if memo and memo.in_types(self):
  52. return {'@': memo.memoized.get(self)}
  53. fields = getattr(self, '__serialize_fields__')
  54. res = {f: _serialize(getattr(self, f), memo) for f in fields}
  55. res['__type__'] = type(self).__name__
  56. if hasattr(self, '_serialize'):
  57. self._serialize(res, memo)
  58. return res
  59. @classmethod
  60. def deserialize(cls: Type[_T], data: Dict[str, Any], memo: Dict[int, Any]) -> _T:
  61. namespace = getattr(cls, '__serialize_namespace__', [])
  62. namespace = {c.__name__:c for c in namespace}
  63. fields = getattr(cls, '__serialize_fields__')
  64. if '@' in data:
  65. return memo[data['@']]
  66. inst = cls.__new__(cls)
  67. for f in fields:
  68. try:
  69. setattr(inst, f, _deserialize(data[f], namespace, memo))
  70. except KeyError as e:
  71. raise KeyError("Cannot find key for class", cls, e)
  72. if hasattr(inst, '_deserialize'):
  73. inst._deserialize()
  74. return inst
  75. class SerializeMemoizer(Serialize):
  76. "A version of serialize that memoizes objects to reduce space"
  77. __serialize_fields__ = 'memoized',
  78. def __init__(self, types_to_memoize: List) -> None:
  79. self.types_to_memoize = tuple(types_to_memoize)
  80. self.memoized = Enumerator()
  81. def in_types(self, value: Serialize) -> bool:
  82. return isinstance(value, self.types_to_memoize)
  83. def serialize(self) -> Dict[int, Any]: # type: ignore[override]
  84. return _serialize(self.memoized.reversed(), None)
  85. @classmethod
  86. def deserialize(cls, data: Dict[int, Any], namespace: Dict[str, Any], memo: Dict[Any, Any]) -> Dict[int, Any]: # type: ignore[override]
  87. return _deserialize(data, namespace, memo)
  88. try:
  89. import regex
  90. _has_regex = True
  91. except ImportError:
  92. _has_regex = False
  93. if sys.version_info >= (3, 11):
  94. import re._parser as sre_parse
  95. import re._constants as sre_constants
  96. else:
  97. import sre_parse
  98. import sre_constants
  99. categ_pattern = re.compile(r'\\p{[A-Za-z_]+}')
  100. def get_regexp_width(expr: str) -> Union[Tuple[int, int], List[int]]:
  101. if _has_regex:
  102. # Since `sre_parse` cannot deal with Unicode categories of the form `\p{Mn}`, we replace these with
  103. # a simple letter, which makes no difference as we are only trying to get the possible lengths of the regex
  104. # match here below.
  105. regexp_final = re.sub(categ_pattern, 'A', expr)
  106. else:
  107. if re.search(categ_pattern, expr):
  108. raise ImportError('`regex` module must be installed in order to use Unicode categories.', expr)
  109. regexp_final = expr
  110. try:
  111. # Fixed in next version (past 0.960) of typeshed
  112. return [int(x) for x in sre_parse.parse(regexp_final).getwidth()]
  113. except sre_constants.error:
  114. if not _has_regex:
  115. raise ValueError(expr)
  116. else:
  117. # sre_parse does not support the new features in regex. To not completely fail in that case,
  118. # we manually test for the most important info (whether the empty string is matched)
  119. c = regex.compile(regexp_final)
  120. # Python 3.11.7 introducded sre_parse.MAXWIDTH that is used instead of MAXREPEAT
  121. # See lark-parser/lark#1376 and python/cpython#109859
  122. MAXWIDTH = getattr(sre_parse, "MAXWIDTH", sre_constants.MAXREPEAT)
  123. if c.match('') is None:
  124. # MAXREPEAT is a none pickable subclass of int, therefore needs to be converted to enable caching
  125. return 1, int(MAXWIDTH)
  126. else:
  127. return 0, int(MAXWIDTH)
  128. @dataclass(frozen=True)
  129. class TextSlice(Generic[AnyStr]):
  130. """A view of a string or bytes object, between the start and end indices.
  131. Never creates a copy.
  132. Lark accepts instances of TextSlice as input (instead of a string),
  133. when the lexer is 'basic' or 'contextual'.
  134. Args:
  135. text (str or bytes): The text to slice.
  136. start (int): The start index. Negative indices are supported.
  137. end (int): The end index. Negative indices are supported.
  138. Raises:
  139. TypeError: If `text` is not a `str` or `bytes`.
  140. AssertionError: If `start` or `end` are out of bounds.
  141. Examples:
  142. >>> TextSlice("Hello, World!", 7, -1)
  143. TextSlice(text='Hello, World!', start=7, end=12)
  144. >>> TextSlice("Hello, World!", 7, None).count("o")
  145. 1
  146. """
  147. text: AnyStr
  148. start: int
  149. end: int
  150. def __post_init__(self):
  151. if not isinstance(self.text, (str, bytes)):
  152. raise TypeError("text must be str or bytes")
  153. if self.start < 0:
  154. object.__setattr__(self, 'start', self.start + len(self.text))
  155. assert self.start >=0
  156. if self.end is None:
  157. object.__setattr__(self, 'end', len(self.text))
  158. elif self.end < 0:
  159. object.__setattr__(self, 'end', self.end + len(self.text))
  160. assert self.end <= len(self.text)
  161. @classmethod
  162. def cast_from(cls, text: 'TextOrSlice') -> 'TextSlice[AnyStr]':
  163. if isinstance(text, TextSlice):
  164. return text
  165. return cls(text, 0, len(text))
  166. def is_complete_text(self):
  167. return self.start == 0 and self.end == len(self.text)
  168. def __len__(self):
  169. return self.end - self.start
  170. def count(self, substr: AnyStr):
  171. return self.text.count(substr, self.start, self.end)
  172. def rindex(self, substr: AnyStr):
  173. return self.text.rindex(substr, self.start, self.end)
# Either plain text (str or bytes) or a TextSlice view over such text. The
# TextSlice member is spelled as a string forward reference.
TextOrSlice = Union[AnyStr, 'TextSlice[AnyStr]']
# Input types accepted by Lark. The trailing ``Any`` makes the alias admit any
# object -- NOTE(review): presumably for non-text input to custom lexers; confirm.
LarkInput = Union[AnyStr, TextSlice[AnyStr], Any]
  176. ###}
  177. _ID_START = 'Lu', 'Ll', 'Lt', 'Lm', 'Lo', 'Mn', 'Mc', 'Pc'
  178. _ID_CONTINUE = _ID_START + ('Nd', 'Nl',)
  179. def _test_unicode_category(s: str, categories: Sequence[str]) -> bool:
  180. if len(s) != 1:
  181. return all(_test_unicode_category(char, categories) for char in s)
  182. return s == '_' or unicodedata.category(s) in categories
  183. def is_id_continue(s: str) -> bool:
  184. """
  185. Checks if all characters in `s` are alphanumeric characters (Unicode standard, so diacritics, indian vowels, non-latin
  186. numbers, etc. all pass). Synonymous with a Python `ID_CONTINUE` identifier. See PEP 3131 for details.
  187. """
  188. return _test_unicode_category(s, _ID_CONTINUE)
  189. def is_id_start(s: str) -> bool:
  190. """
  191. Checks if all characters in `s` are alphabetic characters (Unicode standard, so diacritics, indian vowels, non-latin
  192. numbers, etc. all pass). Synonymous with a Python `ID_START` identifier. See PEP 3131 for details.
  193. """
  194. return _test_unicode_category(s, _ID_START)
  195. def dedup_list(l: Iterable[T]) -> List[T]:
  196. """Given a list (l) will removing duplicates from the list,
  197. preserving the original order of the list. Assumes that
  198. the list entries are hashable."""
  199. return list(dict.fromkeys(l))
  200. class Enumerator(Serialize):
  201. def __init__(self) -> None:
  202. self.enums: Dict[Any, int] = {}
  203. def get(self, item) -> int:
  204. if item not in self.enums:
  205. self.enums[item] = len(self.enums)
  206. return self.enums[item]
  207. def __len__(self):
  208. return len(self.enums)
  209. def reversed(self) -> Dict[int, Any]:
  210. r = {v: k for k, v in self.enums.items()}
  211. assert len(r) == len(self.enums)
  212. return r
  213. def combine_alternatives(lists):
  214. """
  215. Accepts a list of alternatives, and enumerates all their possible concatenations.
  216. Examples:
  217. >>> combine_alternatives([range(2), [4,5]])
  218. [[0, 4], [0, 5], [1, 4], [1, 5]]
  219. >>> combine_alternatives(["abc", "xy", '$'])
  220. [['a', 'x', '$'], ['a', 'y', '$'], ['b', 'x', '$'], ['b', 'y', '$'], ['c', 'x', '$'], ['c', 'y', '$']]
  221. >>> combine_alternatives([])
  222. [[]]
  223. """
  224. if not lists:
  225. return [[]]
  226. assert all(l for l in lists), lists
  227. return list(product(*lists))
  228. try:
  229. import atomicwrites
  230. _has_atomicwrites = True
  231. except ImportError:
  232. _has_atomicwrites = False
  233. class FS:
  234. exists = staticmethod(os.path.exists)
  235. @staticmethod
  236. def open(name, mode="r", **kwargs):
  237. if _has_atomicwrites and "w" in mode:
  238. return atomicwrites.atomic_write(name, mode=mode, overwrite=True, **kwargs)
  239. else:
  240. return open(name, mode, **kwargs)
  241. class fzset(frozenset):
  242. def __repr__(self):
  243. return '{%s}' % ', '.join(map(repr, self))
  244. def classify_bool(seq: Iterable, pred: Callable) -> Any:
  245. false_elems = []
  246. true_elems = [elem for elem in seq if pred(elem) or false_elems.append(elem)] # type: ignore[func-returns-value]
  247. return true_elems, false_elems
  248. def bfs(initial: Iterable, expand: Callable) -> Iterator:
  249. open_q = deque(list(initial))
  250. visited = set(open_q)
  251. while open_q:
  252. node = open_q.popleft()
  253. yield node
  254. for next_node in expand(node):
  255. if next_node not in visited:
  256. visited.add(next_node)
  257. open_q.append(next_node)
  258. def bfs_all_unique(initial, expand):
  259. "bfs, but doesn't keep track of visited (aka seen), because there can be no repetitions"
  260. open_q = deque(list(initial))
  261. while open_q:
  262. node = open_q.popleft()
  263. yield node
  264. open_q += expand(node)
  265. def _serialize(value: Any, memo: Optional[SerializeMemoizer]) -> Any:
  266. if isinstance(value, Serialize):
  267. return value.serialize(memo)
  268. elif isinstance(value, list):
  269. return [_serialize(elem, memo) for elem in value]
  270. elif isinstance(value, frozenset):
  271. return list(value) # TODO reversible?
  272. elif isinstance(value, dict):
  273. return {key:_serialize(elem, memo) for key, elem in value.items()}
  274. # assert value is None or isinstance(value, (int, float, str, tuple)), value
  275. return value
  276. def small_factors(n: int, max_factor: int) -> List[Tuple[int, int]]:
  277. """
  278. Splits n up into smaller factors and summands <= max_factor.
  279. Returns a list of [(a, b), ...]
  280. so that the following code returns n:
  281. n = 1
  282. for a, b in values:
  283. n = n * a + b
  284. Currently, we also keep a + b <= max_factor, but that might change
  285. """
  286. assert n >= 0
  287. assert max_factor > 2
  288. if n <= max_factor:
  289. return [(n, 0)]
  290. for a in range(max_factor, 1, -1):
  291. r, b = divmod(n, a)
  292. if a + b <= max_factor:
  293. return small_factors(r, max_factor) + [(a, b)]
  294. assert False, "Failed to factorize %s" % n
  295. class OrderedSet(AbstractSet[T]):
  296. """A minimal OrderedSet implementation, using a dictionary.
  297. (relies on the dictionary being ordered)
  298. """
  299. def __init__(self, items: Iterable[T] =()):
  300. self.d = dict.fromkeys(items)
  301. def __contains__(self, item: Any) -> bool:
  302. return item in self.d
  303. def add(self, item: T):
  304. self.d[item] = None
  305. def __iter__(self) -> Iterator[T]:
  306. return iter(self.d)
  307. def remove(self, item: T):
  308. del self.d[item]
  309. def __bool__(self):
  310. return bool(self.d)
  311. def __len__(self) -> int:
  312. return len(self.d)
  313. def __repr__(self):
  314. return f"{type(self).__name__}({', '.join(map(repr,self))})"