diff --git a/NEWS.rst b/NEWS.rst index a49f5ec9..0f1b63a4 100644 --- a/NEWS.rst +++ b/NEWS.rst @@ -1,3 +1,12 @@ +v8.1.0 +====== + +Features +-------- + +- Prioritize valid dists to invalid dists when retrieving by name. (#489) + + v8.0.0 ====== diff --git a/importlib_metadata/__init__.py b/importlib_metadata/__init__.py index ed481355..b9fc04f1 100644 --- a/importlib_metadata/__init__.py +++ b/importlib_metadata/__init__.py @@ -25,7 +25,7 @@ install, ) from ._functools import method_cache, pass_none -from ._itertools import always_iterable, unique_everseen +from ._itertools import always_iterable, bucket, unique_everseen from ._meta import PackageMetadata, SimplePath from contextlib import suppress @@ -388,7 +388,7 @@ def from_name(cls, name: str) -> Distribution: if not name: raise ValueError("A distribution name is required.") try: - return next(iter(cls.discover(name=name))) + return next(iter(cls._prefer_valid(cls.discover(name=name)))) except StopIteration: raise PackageNotFoundError(name) @@ -412,6 +412,16 @@ def discover( resolver(context) for resolver in cls._discover_resolvers() ) + @staticmethod + def _prefer_valid(dists: Iterable[Distribution]) -> Iterable[Distribution]: + """ + Prefer (move to the front) distributions that have metadata. + + Ref python/importlib_resources#489. + """ + buckets = bucket(dists, lambda dist: bool(dist.metadata)) + return itertools.chain(buckets[True], buckets[False]) + @staticmethod def at(path: str | os.PathLike[str]) -> Distribution: """Return a Distribution for the indicated metadata path. diff --git a/importlib_metadata/_itertools.py b/importlib_metadata/_itertools.py index d4ca9b91..79d37198 100644 --- a/importlib_metadata/_itertools.py +++ b/importlib_metadata/_itertools.py @@ -1,3 +1,4 @@ +from collections import defaultdict, deque from itertools import filterfalse @@ -71,3 +72,100 @@ def always_iterable(obj, base_type=(str, bytes)): return iter(obj) except TypeError: return iter((obj,)) + + +# Copied from more_itertools 10.3 +class bucket: + """Wrap *iterable* and return an object that buckets the iterable into + child iterables based on a *key* function. + + >>> iterable = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2', 'b3'] + >>> s = bucket(iterable, key=lambda x: x[0]) # Bucket by 1st character + >>> sorted(list(s)) # Get the keys + ['a', 'b', 'c'] + >>> a_iterable = s['a'] + >>> next(a_iterable) + 'a1' + >>> next(a_iterable) + 'a2' + >>> list(s['b']) + ['b1', 'b2', 'b3'] + + The original iterable will be advanced and its items will be cached until + they are used by the child iterables. This may require significant storage. + + By default, attempting to select a bucket to which no items belong will + exhaust the iterable and cache all values. + If you specify a *validator* function, selected buckets will instead be + checked against it. + + >>> from itertools import count + >>> it = count(1, 2) # Infinite sequence of odd numbers + >>> key = lambda x: x % 10 # Bucket by last digit + >>> validator = lambda x: x in {1, 3, 5, 7, 9} # Odd digits only + >>> s = bucket(it, key=key, validator=validator) + >>> 2 in s + False + >>> list(s[2]) + [] + + """ + + def __init__(self, iterable, key, validator=None): + self._it = iter(iterable) + self._key = key + self._cache = defaultdict(deque) + self._validator = validator or (lambda x: True) + + def __contains__(self, value): + if not self._validator(value): + return False + + try: + item = next(self[value]) + except StopIteration: + return False + else: + self._cache[value].appendleft(item) + + return True + + def _get_values(self, value): + """ + Helper to yield items from the parent iterator that match *value*. + Items that don't match are stored in the local cache as they + are encountered. + """ + while True: + # If we've cached some items that match the target value, emit + # the first one and evict it from the cache. + if self._cache[value]: + yield self._cache[value].popleft() + # Otherwise we need to advance the parent iterator to search for + # a matching item, caching the rest. + else: + while True: + try: + item = next(self._it) + except StopIteration: + return + item_value = self._key(item) + if item_value == value: + yield item + break + elif self._validator(item_value): + self._cache[item_value].append(item) + + def __iter__(self): + for item in self._it: + item_value = self._key(item) + if self._validator(item_value): + self._cache[item_value].append(item) + + yield from self._cache.keys() + + def __getitem__(self, value): + if not self._validator(value): + return iter(()) + + return self._get_values(value) diff --git a/pyproject.toml b/pyproject.toml index 8cf8aeb9..24ce25e3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ test = [ "pytest-cov", "pytest-mypy", "pytest-enabler >= 2.2", - "pytest-ruff >= 0.2.1", + "pytest-ruff >= 0.2.1; sys_platform != 'cygwin'", # local 'importlib_resources>=1.3; python_version < "3.9"', diff --git a/ruff.toml b/ruff.toml index 70612985..922aa1f1 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,6 +1,7 @@ [lint] extend-select = [ "C901", + "PERF401", "W", ] ignore = [ @@ -22,7 +23,8 @@ ignore = [ ] [format] -# Enable preview, required for quote-style = "preserve" +# Enable preview to get hugged parenthesis unwrapping and other nice surprises +# See https://github.com/jaraco/skeleton/pull/133#issuecomment-2239538373 preview = true -# https://docs.astral.sh/ruff/settings/#format-quote-style +# https://docs.astral.sh/ruff/settings/#format_quote-style quote-style = "preserve" diff --git a/tests/test_main.py b/tests/test_main.py index f1c12855..dc248492 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -130,6 +130,31 @@ def test_unique_distributions(self): assert len(after) == len(before) +class InvalidMetadataTests(fixtures.OnSysPath, fixtures.SiteDir, unittest.TestCase): + @staticmethod + def make_pkg(name, files=dict(METADATA="VERSION: 1.0")): + """ + Create metadata for a dist-info package with name and files. + """ + return { + f'{name}.dist-info': files, + } + + def test_valid_dists_preferred(self): + """ + Dists with metadata should be preferred when discovered by name. + + Ref python/importlib_metadata#489. + """ + # create three dists with the valid one in the middle (lexicographically) + # such that on most file systems, the valid one is never naturally first. + fixtures.build_files(self.make_pkg('foo-4.0', files={}), self.site_dir) + fixtures.build_files(self.make_pkg('foo-4.1'), self.site_dir) + fixtures.build_files(self.make_pkg('foo-4.2', files={}), self.site_dir) + dist = Distribution.from_name('foo') + assert dist.version == "1.0" + + class NonASCIITests(fixtures.OnSysPath, fixtures.SiteDir, unittest.TestCase): @staticmethod def pkg_with_non_ascii_description(site_dir):
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: