Skip to content

Commit

Permalink
feat: Allow provider to filter unsatisfied names, when backtracking (#…
Browse files Browse the repository at this point in the history
…145)

* Allow provider to narrow backtrack selection

* formatting

* Throw specific error if narrowed_unstatisfied_names is empty

* Increase mccabe complexity

* Update docs

* Add functional tests for narrow_requirement_selection

* Add news entry

* update docs of `get_preference`
  • Loading branch information
notatallshaw authored Aug 9, 2024
1 parent b45601d commit 16d606d
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 8 deletions.
3 changes: 3 additions & 0 deletions news/145.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
New `narrow_requirement_selection` provider method giving option for
providers to reduce the number of times sort key `get_preference` is
called in long running backtrack
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ exclude = [
"*.pyi"
]

[tool.ruff.lint.mccabe]
max-complexity = 12

[tool.mypy]
warn_unused_configs = true

Expand Down
59 changes: 59 additions & 0 deletions src/resolvelib/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ def get_preference(
) -> Preference:
"""Produce a sort key for given requirement based on preference.
As this is a sort key it will be called O(n) times per backtrack
step, where n is the number of `identifier`s, if you have a check
which is expensive in some sense. E.g. It needs to make O(n) checks
per call or takes significant wall clock time, consider using
`narrow_requirement_selection` to filter the `identifier`s, which
is applied before this sort key is called.
The preference is defined as "I think this requirement should be
resolved first". The lower the return value is, the more preferred
this group of arguments is.
Expand Down Expand Up @@ -135,3 +142,55 @@ def get_dependencies(self, candidate: CT) -> Iterable[RT]:
specifies as its dependencies.
"""
raise NotImplementedError

def narrow_requirement_selection(
self,
identifiers: Iterable[KT],
resolutions: Mapping[KT, CT],
candidates: Mapping[KT, Iterator[CT]],
information: Mapping[KT, Iterator[RequirementInformation[RT, CT]]],
backtrack_causes: Sequence[RequirementInformation[RT, CT]],
) -> Iterable[KT]:
"""
An optional method to narrow the selection of requirements being
considered during resolution. This method is called O(1) time per
backtrack step.
:param identifiers: An iterable of `identifiers` as returned by
``identify()``. These identify all requirements currently being
considered.
:param resolutions: A mapping of candidates currently pinned by the
resolver. Each key is an identifier, and the value is a candidate
that may conflict with requirements from ``information``.
:param candidates: A mapping of each dependency's possible candidates.
Each value is an iterator of candidates.
:param information: A mapping of requirement information for each package.
Each value is an iterator of *requirement information*.
:param backtrack_causes: A sequence of *requirement information* that are
the requirements causing the resolver to most recently
backtrack.
A *requirement information* instance is a named tuple with two members:
* ``requirement`` specifies a requirement contributing to the current
list of candidates.
* ``parent`` specifies the candidate that provides (is depended on for)
the requirement, or ``None`` to indicate a root requirement.
Must return a non-empty subset of `identifiers`, with the default
implementation being to return `identifiers` unchanged. Those `identifiers`
will then be passed to the sort key `get_preference` to pick the most
prefered requirement to attempt to pin, unless `narrow_requirement_selection`
returns only 1 requirement, in which case that will be used without
calling the sort key `get_preference`.
This method is designed to be used by the provider to optimize the
dependency resolution, e.g. if a check cost is O(m) and it can be done
against all identifiers at once then filtering the requirement selection
here will cost O(m) but making it part of the sort key in `get_preference`
will cost O(m*n), where n is the number of `identifiers`.
Returns:
Iterable[KT]: A non-empty subset of `identifiers`.
"""
return identifiers
32 changes: 30 additions & 2 deletions src/resolvelib/resolvers/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,8 +411,36 @@ def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT,
# keep track of satisfied names to calculate diff after pinning
satisfied_names = set(self.state.criteria.keys()) - set(unsatisfied_names)

# Choose the most preferred unpinned criterion to try.
name = min(unsatisfied_names, key=self._get_preference)
if len(unsatisfied_names) > 1:
narrowed_unstatisfied_names = list(
self._p.narrow_requirement_selection(
identifiers=unsatisfied_names,
resolutions=self.state.mapping,
candidates=IteratorMapping(
self.state.criteria,
operator.attrgetter("candidates"),
),
information=IteratorMapping(
self.state.criteria,
operator.attrgetter("information"),
),
backtrack_causes=self.state.backtrack_causes,
)
)
else:
narrowed_unstatisfied_names = unsatisfied_names

# If there are no unsatisfied names use unsatisfied names
if not narrowed_unstatisfied_names:
raise RuntimeError("narrow_requirement_selection returned 0 names")

# If there is only 1 unsatisfied name skip calling self._get_preference
if len(narrowed_unstatisfied_names) > 1:
# Choose the most preferred unpinned criterion to try.
name = min(narrowed_unstatisfied_names, key=self._get_preference)
else:
name = narrowed_unstatisfied_names[0]

failure_criterion = self._attempt_to_pin_criterion(name)

if failure_criterion:
Expand Down
42 changes: 36 additions & 6 deletions tests/functional/python/test_resolvers_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,24 @@ def get_dependencies(self, candidate):
return list(self._iter_dependencies(candidate))


class PythonInputProviderNarrowRequirements(PythonInputProvider):
def narrow_requirement_selection(
self, identifiers, resolutions, candidates, information, backtrack_causes
):
# Consider requirements that have 0 candidates (a resolution end point
# that can be backtracked from) or 1 candidate (speeds up situations where
# ever requirement is pinned to 1 specific version)
number_of_candidates = defaultdict(list)
for identifier in identifiers:
number_of_candidates[len(list(candidates[identifier]))].append(identifier)

min_candidates = min(number_of_candidates.keys())
if min_candidates in (0, 1):
return number_of_candidates[min_candidates]

return identifiers


INPUTS_DIR = os.path.abspath(os.path.join(__file__, "..", "inputs"))

CASE_DIR = os.path.join(INPUTS_DIR, "case")
Expand All @@ -133,20 +151,32 @@ def get_dependencies(self, candidate):
}


@pytest.fixture(
params=[
def create_params(provider_class):
return [
pytest.param(
os.path.join(CASE_DIR, n),
(os.path.join(CASE_DIR, n), provider_class),
marks=pytest.mark.xfail(strict=True, reason=XFAIL_CASES[n]),
)
if n in XFAIL_CASES
else os.path.join(CASE_DIR, n)
else (os.path.join(CASE_DIR, n), provider_class)
for n in CASE_NAMES
]


@pytest.fixture(
params=[
*create_params(PythonInputProvider),
*create_params(PythonInputProviderNarrowRequirements),
],
ids=[
f"{n[:-5]}-{cls.__name__}"
for cls in [PythonInputProvider, PythonInputProviderNarrowRequirements]
for n in CASE_NAMES
],
ids=[n[:-5] for n in CASE_NAMES],
)
def provider(request):
return PythonInputProvider(request.param)
path, provider_class = request.param
return provider_class(path)


def _format_confliction(exception):
Expand Down

0 comments on commit 16d606d

Please sign in to comment.