structs.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. from __future__ import annotations
  2. import itertools
  3. from collections import namedtuple
  4. from typing import (
  5. TYPE_CHECKING,
  6. Callable,
  7. Generic,
  8. Iterable,
  9. Iterator,
  10. Mapping,
  11. NamedTuple,
  12. Sequence,
  13. TypeVar,
  14. Union,
  15. )
# Type variables used by the generic containers in this module.
KT = TypeVar("KT")  # Identifier.
RT = TypeVar("RT")  # Requirement.
CT = TypeVar("CT")  # Candidate.

# Either an iterable of candidates, or a zero-argument callable that returns
# such an iterable when invoked.
Matches = Union[Iterable[CT], Callable[[], Iterable[CT]]]
  20. if TYPE_CHECKING:
  21. from .resolvers.criterion import Criterion
  22. class RequirementInformation(NamedTuple, Generic[RT, CT]):
  23. requirement: RT
  24. parent: CT | None
  25. class State(NamedTuple, Generic[RT, CT, KT]):
  26. """Resolution state in a round."""
  27. mapping: dict[KT, CT]
  28. criteria: dict[KT, Criterion[RT, CT]]
  29. backtrack_causes: list[RequirementInformation[RT, CT]]
  30. else:
  31. RequirementInformation = namedtuple(
  32. "RequirementInformation", ["requirement", "parent"]
  33. )
  34. State = namedtuple("State", ["mapping", "criteria", "backtrack_causes"])
  35. class DirectedGraph(Generic[KT]):
  36. """A graph structure with directed edges."""
  37. def __init__(self) -> None:
  38. self._vertices: set[KT] = set()
  39. self._forwards: dict[KT, set[KT]] = {} # <key> -> Set[<key>]
  40. self._backwards: dict[KT, set[KT]] = {} # <key> -> Set[<key>]
  41. def __iter__(self) -> Iterator[KT]:
  42. return iter(self._vertices)
  43. def __len__(self) -> int:
  44. return len(self._vertices)
  45. def __contains__(self, key: KT) -> bool:
  46. return key in self._vertices
  47. def copy(self) -> DirectedGraph[KT]:
  48. """Return a shallow copy of this graph."""
  49. other = type(self)()
  50. other._vertices = set(self._vertices)
  51. other._forwards = {k: set(v) for k, v in self._forwards.items()}
  52. other._backwards = {k: set(v) for k, v in self._backwards.items()}
  53. return other
  54. def add(self, key: KT) -> None:
  55. """Add a new vertex to the graph."""
  56. if key in self._vertices:
  57. raise ValueError("vertex exists")
  58. self._vertices.add(key)
  59. self._forwards[key] = set()
  60. self._backwards[key] = set()
  61. def remove(self, key: KT) -> None:
  62. """Remove a vertex from the graph, disconnecting all edges from/to it."""
  63. self._vertices.remove(key)
  64. for f in self._forwards.pop(key):
  65. self._backwards[f].remove(key)
  66. for t in self._backwards.pop(key):
  67. self._forwards[t].remove(key)
  68. def connected(self, f: KT, t: KT) -> bool:
  69. return f in self._backwards[t] and t in self._forwards[f]
  70. def connect(self, f: KT, t: KT) -> None:
  71. """Connect two existing vertices.
  72. Nothing happens if the vertices are already connected.
  73. """
  74. if t not in self._vertices:
  75. raise KeyError(t)
  76. self._forwards[f].add(t)
  77. self._backwards[t].add(f)
  78. def iter_edges(self) -> Iterator[tuple[KT, KT]]:
  79. for f, children in self._forwards.items():
  80. for t in children:
  81. yield f, t
  82. def iter_children(self, key: KT) -> Iterator[KT]:
  83. return iter(self._forwards[key])
  84. def iter_parents(self, key: KT) -> Iterator[KT]:
  85. return iter(self._backwards[key])
  86. class IteratorMapping(Mapping[KT, Iterator[CT]], Generic[RT, CT, KT]):
  87. def __init__(
  88. self,
  89. mapping: Mapping[KT, RT],
  90. accessor: Callable[[RT], Iterable[CT]],
  91. appends: Mapping[KT, Iterable[CT]] | None = None,
  92. ) -> None:
  93. self._mapping = mapping
  94. self._accessor = accessor
  95. self._appends: Mapping[KT, Iterable[CT]] = appends or {}
  96. def __repr__(self) -> str:
  97. return "IteratorMapping({!r}, {!r}, {!r})".format(
  98. self._mapping,
  99. self._accessor,
  100. self._appends,
  101. )
  102. def __bool__(self) -> bool:
  103. return bool(self._mapping or self._appends)
  104. def __contains__(self, key: object) -> bool:
  105. return key in self._mapping or key in self._appends
  106. def __getitem__(self, k: KT) -> Iterator[CT]:
  107. try:
  108. v = self._mapping[k]
  109. except KeyError:
  110. return iter(self._appends[k])
  111. return itertools.chain(self._accessor(v), self._appends.get(k, ()))
  112. def __iter__(self) -> Iterator[KT]:
  113. more = (k for k in self._appends if k not in self._mapping)
  114. return itertools.chain(self._mapping, more)
  115. def __len__(self) -> int:
  116. more = sum(1 for k in self._appends if k not in self._mapping)
  117. return len(self._mapping) + more
  118. class _FactoryIterableView(Iterable[RT]):
  119. """Wrap an iterator factory returned by `find_matches()`.
  120. Calling `iter()` on this class would invoke the underlying iterator
  121. factory, making it a "collection with ordering" that can be iterated
  122. through multiple times, but lacks random access methods presented in
  123. built-in Python sequence types.
  124. """
  125. def __init__(self, factory: Callable[[], Iterable[RT]]) -> None:
  126. self._factory = factory
  127. self._iterable: Iterable[RT] | None = None
  128. def __repr__(self) -> str:
  129. return f"{type(self).__name__}({list(self)})"
  130. def __bool__(self) -> bool:
  131. try:
  132. next(iter(self))
  133. except StopIteration:
  134. return False
  135. return True
  136. def __iter__(self) -> Iterator[RT]:
  137. iterable = self._factory() if self._iterable is None else self._iterable
  138. self._iterable, current = itertools.tee(iterable)
  139. return current
  140. class _SequenceIterableView(Iterable[RT]):
  141. """Wrap an iterable returned by find_matches().
  142. This is essentially just a proxy to the underlying sequence that provides
  143. the same interface as `_FactoryIterableView`.
  144. """
  145. def __init__(self, sequence: Sequence[RT]):
  146. self._sequence = sequence
  147. def __repr__(self) -> str:
  148. return f"{type(self).__name__}({self._sequence})"
  149. def __bool__(self) -> bool:
  150. return bool(self._sequence)
  151. def __iter__(self) -> Iterator[RT]:
  152. return iter(self._sequence)
  153. def build_iter_view(matches: Matches[CT]) -> Iterable[CT]:
  154. """Build an iterable view from the value returned by `find_matches()`."""
  155. if callable(matches):
  156. return _FactoryIterableView(matches)
  157. if not isinstance(matches, Sequence):
  158. matches = list(matches)
  159. return _SequenceIterableView(matches)
# Module-level alias of ``typing.Iterable``.
# NOTE(review): presumably the public name for annotating the views returned
# by ``build_iter_view`` -- confirm against callers.
IterableView = Iterable