resolution.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. from __future__ import annotations
  2. import collections
  3. import itertools
  4. import operator
  5. from typing import TYPE_CHECKING, Generic
  6. from ..structs import (
  7. CT,
  8. KT,
  9. RT,
  10. DirectedGraph,
  11. IterableView,
  12. IteratorMapping,
  13. RequirementInformation,
  14. State,
  15. build_iter_view,
  16. )
  17. from .abstract import AbstractResolver, Result
  18. from .criterion import Criterion
  19. from .exceptions import (
  20. InconsistentCandidate,
  21. RequirementsConflicted,
  22. ResolutionImpossible,
  23. ResolutionTooDeep,
  24. ResolverException,
  25. )
  26. if TYPE_CHECKING:
  27. from collections.abc import Collection, Iterable, Mapping
  28. from ..providers import AbstractProvider, Preference
  29. from ..reporters import BaseReporter
  30. _OPTIMISTIC_BACKJUMPING_RATIO: float = 0.1
  31. def _build_result(state: State[RT, CT, KT]) -> Result[RT, CT, KT]:
  32. mapping = state.mapping
  33. all_keys: dict[int, KT | None] = {id(v): k for k, v in mapping.items()}
  34. all_keys[id(None)] = None
  35. graph: DirectedGraph[KT | None] = DirectedGraph()
  36. graph.add(None) # Sentinel as root dependencies' parent.
  37. connected: set[KT | None] = {None}
  38. for key, criterion in state.criteria.items():
  39. if not _has_route_to_root(state.criteria, key, all_keys, connected):
  40. continue
  41. if key not in graph:
  42. graph.add(key)
  43. for p in criterion.iter_parent():
  44. try:
  45. pkey = all_keys[id(p)]
  46. except KeyError:
  47. continue
  48. if pkey not in graph:
  49. graph.add(pkey)
  50. graph.connect(pkey, key)
  51. return Result(
  52. mapping={k: v for k, v in mapping.items() if k in connected},
  53. graph=graph,
  54. criteria=state.criteria,
  55. )
  56. class Resolution(Generic[RT, CT, KT]):
  57. """Stateful resolution object.
  58. This is designed as a one-off object that holds information to kick start
  59. the resolution process, and holds the results afterwards.
  60. """
  61. def __init__(
  62. self,
  63. provider: AbstractProvider[RT, CT, KT],
  64. reporter: BaseReporter[RT, CT, KT],
  65. ) -> None:
  66. self._p = provider
  67. self._r = reporter
  68. self._states: list[State[RT, CT, KT]] = []
  69. # Optimistic backjumping variables
  70. self._optimistic_backjumping_ratio = _OPTIMISTIC_BACKJUMPING_RATIO
  71. self._save_states: list[State[RT, CT, KT]] | None = None
  72. self._optimistic_start_round: int | None = None
  73. @property
  74. def state(self) -> State[RT, CT, KT]:
  75. try:
  76. return self._states[-1]
  77. except IndexError as e:
  78. raise AttributeError("state") from e
  79. def _push_new_state(self) -> None:
  80. """Push a new state into history.
  81. This new state will be used to hold resolution results of the next
  82. coming round.
  83. """
  84. base = self._states[-1]
  85. state = State(
  86. mapping=base.mapping.copy(),
  87. criteria=base.criteria.copy(),
  88. backtrack_causes=base.backtrack_causes[:],
  89. )
  90. self._states.append(state)
  91. def _add_to_criteria(
  92. self,
  93. criteria: dict[KT, Criterion[RT, CT]],
  94. requirement: RT,
  95. parent: CT | None,
  96. ) -> None:
  97. self._r.adding_requirement(requirement=requirement, parent=parent)
  98. identifier = self._p.identify(requirement_or_candidate=requirement)
  99. criterion = criteria.get(identifier)
  100. if criterion:
  101. incompatibilities = list(criterion.incompatibilities)
  102. else:
  103. incompatibilities = []
  104. matches = self._p.find_matches(
  105. identifier=identifier,
  106. requirements=IteratorMapping(
  107. criteria,
  108. operator.methodcaller("iter_requirement"),
  109. {identifier: [requirement]},
  110. ),
  111. incompatibilities=IteratorMapping(
  112. criteria,
  113. operator.attrgetter("incompatibilities"),
  114. {identifier: incompatibilities},
  115. ),
  116. )
  117. if criterion:
  118. information = list(criterion.information)
  119. information.append(RequirementInformation(requirement, parent))
  120. else:
  121. information = [RequirementInformation(requirement, parent)]
  122. criterion = Criterion(
  123. candidates=build_iter_view(matches),
  124. information=information,
  125. incompatibilities=incompatibilities,
  126. )
  127. if not criterion.candidates:
  128. raise RequirementsConflicted(criterion)
  129. criteria[identifier] = criterion
  130. def _remove_information_from_criteria(
  131. self, criteria: dict[KT, Criterion[RT, CT]], parents: Collection[KT]
  132. ) -> None:
  133. """Remove information from parents of criteria.
  134. Concretely, removes all values from each criterion's ``information``
  135. field that have one of ``parents`` as provider of the requirement.
  136. :param criteria: The criteria to update.
  137. :param parents: Identifiers for which to remove information from all criteria.
  138. """
  139. if not parents:
  140. return
  141. for key, criterion in criteria.items():
  142. criteria[key] = Criterion(
  143. criterion.candidates,
  144. [
  145. information
  146. for information in criterion.information
  147. if (
  148. information.parent is None
  149. or self._p.identify(information.parent) not in parents
  150. )
  151. ],
  152. criterion.incompatibilities,
  153. )
  154. def _get_preference(self, name: KT) -> Preference:
  155. return self._p.get_preference(
  156. identifier=name,
  157. resolutions=self.state.mapping,
  158. candidates=IteratorMapping(
  159. self.state.criteria,
  160. operator.attrgetter("candidates"),
  161. ),
  162. information=IteratorMapping(
  163. self.state.criteria,
  164. operator.attrgetter("information"),
  165. ),
  166. backtrack_causes=self.state.backtrack_causes,
  167. )
  168. def _is_current_pin_satisfying(
  169. self, name: KT, criterion: Criterion[RT, CT]
  170. ) -> bool:
  171. try:
  172. current_pin = self.state.mapping[name]
  173. except KeyError:
  174. return False
  175. return all(
  176. self._p.is_satisfied_by(requirement=r, candidate=current_pin)
  177. for r in criterion.iter_requirement()
  178. )
  179. def _get_updated_criteria(self, candidate: CT) -> dict[KT, Criterion[RT, CT]]:
  180. criteria = self.state.criteria.copy()
  181. for requirement in self._p.get_dependencies(candidate=candidate):
  182. self._add_to_criteria(criteria, requirement, parent=candidate)
  183. return criteria
  184. def _attempt_to_pin_criterion(self, name: KT) -> list[Criterion[RT, CT]]:
  185. criterion = self.state.criteria[name]
  186. causes: list[Criterion[RT, CT]] = []
  187. for candidate in criterion.candidates:
  188. try:
  189. criteria = self._get_updated_criteria(candidate)
  190. except RequirementsConflicted as e:
  191. self._r.rejecting_candidate(e.criterion, candidate)
  192. causes.append(e.criterion)
  193. continue
  194. # Check the newly-pinned candidate actually works. This should
  195. # always pass under normal circumstances, but in the case of a
  196. # faulty provider, we will raise an error to notify the implementer
  197. # to fix find_matches() and/or is_satisfied_by().
  198. satisfied = all(
  199. self._p.is_satisfied_by(requirement=r, candidate=candidate)
  200. for r in criterion.iter_requirement()
  201. )
  202. if not satisfied:
  203. raise InconsistentCandidate(candidate, criterion)
  204. self._r.pinning(candidate=candidate)
  205. self.state.criteria.update(criteria)
  206. # Put newly-pinned candidate at the end. This is essential because
  207. # backtracking looks at this mapping to get the last pin.
  208. self.state.mapping.pop(name, None)
  209. self.state.mapping[name] = candidate
  210. return []
  211. # All candidates tried, nothing works. This criterion is a dead
  212. # end, signal for backtracking.
  213. return causes
  214. def _patch_criteria(
  215. self, incompatibilities_from_broken: list[tuple[KT, list[CT]]]
  216. ) -> bool:
  217. # Create a new state from the last known-to-work one, and apply
  218. # the previously gathered incompatibility information.
  219. for k, incompatibilities in incompatibilities_from_broken:
  220. if not incompatibilities:
  221. continue
  222. try:
  223. criterion = self.state.criteria[k]
  224. except KeyError:
  225. continue
  226. matches = self._p.find_matches(
  227. identifier=k,
  228. requirements=IteratorMapping(
  229. self.state.criteria,
  230. operator.methodcaller("iter_requirement"),
  231. ),
  232. incompatibilities=IteratorMapping(
  233. self.state.criteria,
  234. operator.attrgetter("incompatibilities"),
  235. {k: incompatibilities},
  236. ),
  237. )
  238. candidates: IterableView[CT] = build_iter_view(matches)
  239. if not candidates:
  240. return False
  241. incompatibilities.extend(criterion.incompatibilities)
  242. self.state.criteria[k] = Criterion(
  243. candidates=candidates,
  244. information=list(criterion.information),
  245. incompatibilities=incompatibilities,
  246. )
  247. return True
  248. def _save_state(self) -> None:
  249. """Save states for potential rollback if optimistic backjumping fails."""
  250. if self._save_states is None:
  251. self._save_states = [
  252. State(
  253. mapping=s.mapping.copy(),
  254. criteria=s.criteria.copy(),
  255. backtrack_causes=s.backtrack_causes[:],
  256. )
  257. for s in self._states
  258. ]
  259. def _rollback_states(self) -> None:
  260. """Rollback states and disable optimistic backjumping."""
  261. self._optimistic_backjumping_ratio = 0.0
  262. if self._save_states:
  263. self._states = self._save_states
  264. self._save_states = None
  265. def _backjump(self, causes: list[RequirementInformation[RT, CT]]) -> bool:
  266. """Perform backjumping.
  267. When we enter here, the stack is like this::
  268. [ state Z ]
  269. [ state Y ]
  270. [ state X ]
  271. .... earlier states are irrelevant.
  272. 1. No pins worked for Z, so it does not have a pin.
  273. 2. We want to reset state Y to unpinned, and pin another candidate.
  274. 3. State X holds what state Y was before the pin, but does not
  275. have the incompatibility information gathered in state Y.
  276. Each iteration of the loop will:
  277. 1. Identify Z. The incompatibility is not always caused by the latest
  278. state. For example, given three requirements A, B and C, with
  279. dependencies A1, B1 and C1, where A1 and B1 are incompatible: the
  280. last state might be related to C, so we want to discard the
  281. previous state.
  282. 2. Discard Z.
  283. 3. Discard Y but remember its incompatibility information gathered
  284. previously, and the failure we're dealing with right now.
  285. 4. Push a new state Y' based on X, and apply the incompatibility
  286. information from Y to Y'.
  287. 5a. If this causes Y' to conflict, we need to backtrack again. Make Y'
  288. the new Z and go back to step 2.
  289. 5b. If the incompatibilities apply cleanly, end backtracking.
  290. """
  291. incompatible_reqs: Iterable[CT | RT] = itertools.chain(
  292. (c.parent for c in causes if c.parent is not None),
  293. (c.requirement for c in causes),
  294. )
  295. incompatible_deps = {self._p.identify(r) for r in incompatible_reqs}
  296. while len(self._states) >= 3:
  297. # Remove the state that triggered backtracking.
  298. del self._states[-1]
  299. # Optimistically backtrack to a state that caused the incompatibility
  300. broken_state = self.state
  301. while True:
  302. # Retrieve the last candidate pin and known incompatibilities.
  303. try:
  304. broken_state = self._states.pop()
  305. name, candidate = broken_state.mapping.popitem()
  306. except (IndexError, KeyError):
  307. raise ResolutionImpossible(causes) from None
  308. if (
  309. not self._optimistic_backjumping_ratio
  310. and name not in incompatible_deps
  311. ):
  312. # For safe backjumping only backjump if the current dependency
  313. # is not the same as the incompatible dependency
  314. break
  315. # On the first time a non-safe backjump is done the state
  316. # is saved so we can restore it later if the resolution fails
  317. if (
  318. self._optimistic_backjumping_ratio
  319. and self._save_states is None
  320. and name not in incompatible_deps
  321. ):
  322. self._save_state()
  323. # If the current dependencies and the incompatible dependencies
  324. # are overlapping then we have likely found a cause of the
  325. # incompatibility
  326. current_dependencies = {
  327. self._p.identify(d) for d in self._p.get_dependencies(candidate)
  328. }
  329. if not current_dependencies.isdisjoint(incompatible_deps):
  330. break
  331. # Fallback: We should not backtrack to the point where
  332. # broken_state.mapping is empty, so stop backtracking for
  333. # a chance for the resolution to recover
  334. if not broken_state.mapping:
  335. break
  336. # Guard: We need at least two state to remain to both
  337. # backtrack and push a new state
  338. if len(self._states) <= 1:
  339. raise ResolutionImpossible(causes)
  340. incompatibilities_from_broken = [
  341. (k, list(v.incompatibilities)) for k, v in broken_state.criteria.items()
  342. ]
  343. # Also mark the newly known incompatibility.
  344. incompatibilities_from_broken.append((name, [candidate]))
  345. self._push_new_state()
  346. success = self._patch_criteria(incompatibilities_from_broken)
  347. # It works! Let's work on this new state.
  348. if success:
  349. return True
  350. # State does not work after applying known incompatibilities.
  351. # Try the still previous state.
  352. # No way to backtrack anymore.
  353. return False
  354. def _extract_causes(
  355. self, criteron: list[Criterion[RT, CT]]
  356. ) -> list[RequirementInformation[RT, CT]]:
  357. """Extract causes from list of criterion and deduplicate"""
  358. return list({id(i): i for c in criteron for i in c.information}.values())
  359. def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT, KT]:
  360. if self._states:
  361. raise RuntimeError("already resolved")
  362. self._r.starting()
  363. # Initialize the root state.
  364. self._states = [
  365. State(
  366. mapping=collections.OrderedDict(),
  367. criteria={},
  368. backtrack_causes=[],
  369. )
  370. ]
  371. for r in requirements:
  372. try:
  373. self._add_to_criteria(self.state.criteria, r, parent=None)
  374. except RequirementsConflicted as e:
  375. raise ResolutionImpossible(e.criterion.information) from e
  376. # The root state is saved as a sentinel so the first ever pin can have
  377. # something to backtrack to if it fails. The root state is basically
  378. # pinning the virtual "root" package in the graph.
  379. self._push_new_state()
  380. # Variables for optimistic backjumping
  381. optimistic_rounds_cutoff: int | None = None
  382. optimistic_backjumping_start_round: int | None = None
  383. for round_index in range(max_rounds):
  384. self._r.starting_round(index=round_index)
  385. # Handle if optimistic backjumping has been running for too long
  386. if self._optimistic_backjumping_ratio and self._save_states is not None:
  387. if optimistic_backjumping_start_round is None:
  388. optimistic_backjumping_start_round = round_index
  389. optimistic_rounds_cutoff = int(
  390. (max_rounds - round_index) * self._optimistic_backjumping_ratio
  391. )
  392. if optimistic_rounds_cutoff <= 0:
  393. self._rollback_states()
  394. continue
  395. elif optimistic_rounds_cutoff is not None:
  396. if (
  397. round_index - optimistic_backjumping_start_round
  398. >= optimistic_rounds_cutoff
  399. ):
  400. self._rollback_states()
  401. continue
  402. unsatisfied_names = [
  403. key
  404. for key, criterion in self.state.criteria.items()
  405. if not self._is_current_pin_satisfying(key, criterion)
  406. ]
  407. # All criteria are accounted for. Nothing more to pin, we are done!
  408. if not unsatisfied_names:
  409. self._r.ending(state=self.state)
  410. return self.state
  411. # keep track of satisfied names to calculate diff after pinning
  412. satisfied_names = set(self.state.criteria.keys()) - set(unsatisfied_names)
  413. if len(unsatisfied_names) > 1:
  414. narrowed_unstatisfied_names = list(
  415. self._p.narrow_requirement_selection(
  416. identifiers=unsatisfied_names,
  417. resolutions=self.state.mapping,
  418. candidates=IteratorMapping(
  419. self.state.criteria,
  420. operator.attrgetter("candidates"),
  421. ),
  422. information=IteratorMapping(
  423. self.state.criteria,
  424. operator.attrgetter("information"),
  425. ),
  426. backtrack_causes=self.state.backtrack_causes,
  427. )
  428. )
  429. else:
  430. narrowed_unstatisfied_names = unsatisfied_names
  431. # If there are no unsatisfied names use unsatisfied names
  432. if not narrowed_unstatisfied_names:
  433. raise RuntimeError("narrow_requirement_selection returned 0 names")
  434. # If there is only 1 unsatisfied name skip calling self._get_preference
  435. if len(narrowed_unstatisfied_names) > 1:
  436. # Choose the most preferred unpinned criterion to try.
  437. name = min(narrowed_unstatisfied_names, key=self._get_preference)
  438. else:
  439. name = narrowed_unstatisfied_names[0]
  440. failure_criterion = self._attempt_to_pin_criterion(name)
  441. if failure_criterion:
  442. causes = self._extract_causes(failure_criterion)
  443. # Backjump if pinning fails. The backjump process puts us in
  444. # an unpinned state, so we can work on it in the next round.
  445. self._r.resolving_conflicts(causes=causes)
  446. try:
  447. success = self._backjump(causes)
  448. except ResolutionImpossible:
  449. if self._optimistic_backjumping_ratio and self._save_states:
  450. failed_optimistic_backjumping = True
  451. else:
  452. raise
  453. else:
  454. failed_optimistic_backjumping = bool(
  455. not success
  456. and self._optimistic_backjumping_ratio
  457. and self._save_states
  458. )
  459. if failed_optimistic_backjumping and self._save_states:
  460. self._rollback_states()
  461. else:
  462. self.state.backtrack_causes[:] = causes
  463. # Dead ends everywhere. Give up.
  464. if not success:
  465. raise ResolutionImpossible(self.state.backtrack_causes)
  466. else:
  467. # discard as information sources any invalidated names
  468. # (unsatisfied names that were previously satisfied)
  469. newly_unsatisfied_names = {
  470. key
  471. for key, criterion in self.state.criteria.items()
  472. if key in satisfied_names
  473. and not self._is_current_pin_satisfying(key, criterion)
  474. }
  475. self._remove_information_from_criteria(
  476. self.state.criteria, newly_unsatisfied_names
  477. )
  478. # Pinning was successful. Push a new state to do another pin.
  479. self._push_new_state()
  480. self._r.ending_round(index=round_index, state=self.state)
  481. raise ResolutionTooDeep(max_rounds)
  482. class Resolver(AbstractResolver[RT, CT, KT]):
  483. """The thing that performs the actual resolution work."""
  484. base_exception = ResolverException
  485. def resolve( # type: ignore[override]
  486. self,
  487. requirements: Iterable[RT],
  488. max_rounds: int = 100,
  489. ) -> Result[RT, CT, KT]:
  490. """Take a collection of constraints, spit out the resolution result.
  491. The return value is a representation to the final resolution result. It
  492. is a tuple subclass with three public members:
  493. * `mapping`: A dict of resolved candidates. Each key is an identifier
  494. of a requirement (as returned by the provider's `identify` method),
  495. and the value is the resolved candidate.
  496. * `graph`: A `DirectedGraph` instance representing the dependency tree.
  497. The vertices are keys of `mapping`, and each edge represents *why*
  498. a particular package is included. A special vertex `None` is
  499. included to represent parents of user-supplied requirements.
  500. * `criteria`: A dict of "criteria" that hold detailed information on
  501. how edges in the graph are derived. Each key is an identifier of a
  502. requirement, and the value is a `Criterion` instance.
  503. The following exceptions may be raised if a resolution cannot be found:
  504. * `ResolutionImpossible`: A resolution cannot be found for the given
  505. combination of requirements. The `causes` attribute of the
  506. exception is a list of (requirement, parent), giving the
  507. requirements that could not be satisfied.
  508. * `ResolutionTooDeep`: The dependency tree is too deeply nested and
  509. the resolver gave up. This is usually caused by a circular
  510. dependency, but you can try to resolve this by increasing the
  511. `max_rounds` argument.
  512. """
  513. resolution = Resolution(self.provider, self.reporter)
  514. state = resolution.resolve(requirements, max_rounds=max_rounds)
  515. return _build_result(state)
  516. def _has_route_to_root(
  517. criteria: Mapping[KT, Criterion[RT, CT]],
  518. key: KT | None,
  519. all_keys: dict[int, KT | None],
  520. connected: set[KT | None],
  521. ) -> bool:
  522. if key in connected:
  523. return True
  524. if key not in criteria:
  525. return False
  526. assert key is not None
  527. for p in criteria[key].iter_parent():
  528. try:
  529. pkey = all_keys[id(p)]
  530. except KeyError:
  531. continue
  532. if pkey in connected:
  533. connected.add(key)
  534. return True
  535. if _has_route_to_root(criteria, pkey, all_keys, connected):
  536. connected.add(key)
  537. return True
  538. return False