Skip to content

Commit

Permalink
fix typing
Browse files Browse the repository at this point in the history
Signed-off-by: Frost Ming <[email protected]>
  • Loading branch information
frostming committed Sep 12, 2023
1 parent b27f5f4 commit 40ccfdd
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 6 deletions.
8 changes: 6 additions & 2 deletions src/resolvelib/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from .structs import CT, KT, RT, Matches, RequirementInformation

if TYPE_CHECKING:
from _typeshed import SupportsRichComparison
from typing import Any, Protocol

class Preference(Protocol):
def __lt__(self, __other: Any) -> bool:
...


class AbstractProvider(Generic[RT, CT, KT]):
Expand All @@ -34,7 +38,7 @@ def get_preference(
candidates: Mapping[KT, Iterator[CT]],
information: Mapping[KT, Iterator[RequirementInformation[RT, CT]]],
backtrack_causes: Sequence[RequirementInformation[RT, CT]],
) -> SupportsRichComparison:
) -> Preference:
"""Produce a sort key for given requirement based on preference.
The preference is defined as "I think this requirement should be
Expand Down
4 changes: 2 additions & 2 deletions src/resolvelib/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)

if TYPE_CHECKING:
from _typeshed import SupportsRichComparison
from .providers import Preference

class Result(NamedTuple, Generic[RT, CT, KT]):
mapping: Mapping[KT, CT]
Expand Down Expand Up @@ -197,7 +197,7 @@ def _remove_information_from_criteria(
criterion.incompatibilities,
)

def _get_preference(self, name: KT) -> SupportsRichComparison:
def _get_preference(self, name: KT) -> Preference:
return self._p.get_preference(
identifier=name,
resolutions=self.state.mapping,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def test_pin_conflict_with_self(monkeypatch, reporter):
],
}

class Provider(AbstractProvider[str, str, Candidate]):
class Provider(AbstractProvider[str, Candidate, str]):
def identify(self, requirement_or_candidate: str | Candidate) -> str:
result = (
Requirement(requirement_or_candidate).name
Expand Down Expand Up @@ -260,7 +260,7 @@ def get_updated_criteria_patch(self, candidate):
Resolution, "_get_updated_criteria", get_updated_criteria_patch
)

resolver: Resolver[str, str, Candidate] = Resolver(Provider(), reporter)
resolver: Resolver[str, Candidate, str] = Resolver(Provider(), reporter)
result = resolver.resolve(["child", "parent"])

def get_child_versions(
Expand Down

0 comments on commit 40ccfdd

Please sign in to comment.