Skip to content

ENH: Enable custom compression levels in np.savez_compressed #29294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 100 additions & 12 deletions numpy/lib/_npyio_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,14 @@ def zipfile_factory(file, *args, **kwargs):
if not hasattr(file, 'read'):
file = os.fspath(file)
import zipfile
kwargs['allowZip64'] = True
return zipfile.ZipFile(file, *args, **kwargs)

# Handle compression parameters
compresslevel = kwargs.pop('compresslevel', None)
compression = kwargs.pop('compression', zipfile.ZIP_STORED)

# ``zipfile.ZipFile`` accepts ``compresslevel`` since Python 3.7
return zipfile.ZipFile(file, *args, compression=compression,
compresslevel=compresslevel, **kwargs)


@set_module('numpy.lib.npyio')
Expand Down Expand Up @@ -686,13 +692,15 @@ def savez(file, *args, allow_pickle=True, **kwds):
_savez(file, args, kwds, False, allow_pickle=allow_pickle)


def _savez_compressed_dispatcher(file, *args, allow_pickle=True, **kwds):
def _savez_compressed_dispatcher(file, *args, allow_pickle=True,
zipfile_kwargs=None, **kwds):
yield from args
yield from kwds.values()


@array_function_dispatch(_savez_compressed_dispatcher)
def savez_compressed(file, *args, allow_pickle=True, **kwds):
def savez_compressed(file, *args, allow_pickle=True,
zipfile_kwargs=None, **kwds):
"""
Save several arrays into a single file in compressed ``.npz`` format.

Expand Down Expand Up @@ -721,6 +729,10 @@ def savez_compressed(file, *args, allow_pickle=True, **kwds):
require libraries that are not available, and not all pickled data is
compatible between different versions of Python).
Default: True
zipfile_kwargs : dict, optional
Dictionary of keyword arguments passed on to ``zipfile.ZipFile``
(e.g. ``compression``, ``compresslevel``) after ``compression`` and
``compresslevel`` have been validated.
If ``compression`` is not given, it is set to ``ZIP_DEFLATED``.
kwds : Keyword arguments, optional
Arrays to save to the file. Each array will be saved to the
output file with its corresponding keyword name.
Expand Down Expand Up @@ -763,12 +775,12 @@ def savez_compressed(file, *args, allow_pickle=True, **kwds):
True

"""
_savez(file, args, kwds, True, allow_pickle=allow_pickle)
_savez(file, args, kwds, True, allow_pickle=allow_pickle,
zipfile_kwargs=zipfile_kwargs)


def _savez(file, args, kwds, compress, allow_pickle=True, pickle_kwargs=None):
# Import is postponed to here since zipfile depends on gzip, an optional
# component of the so-called standard library.
def _savez(file, args, kwds, compress, allow_pickle=True, pickle_kwargs=None,
zipfile_kwargs=None):
import zipfile

if not hasattr(file, 'write'):
Expand All @@ -784,12 +796,88 @@ def _savez(file, args, kwds, compress, allow_pickle=True, pickle_kwargs=None):
f"Cannot use un-named variables and keyword {key}")
namedict[key] = val

if compress:
compression = zipfile.ZIP_DEFLATED
# Prepare ZipFile keyword arguments
if zipfile_kwargs is None:
zipfile_kwargs = {}

# Default behaviour: use DEFLATED for the compressed variant, STORED
# otherwise – unless the user explicitly asked for something else.
comp = zipfile_kwargs.get("compression")
if comp is None:
comp = zipfile.ZIP_DEFLATED if compress else zipfile.ZIP_STORED
else:
compression = zipfile.ZIP_STORED
# Translate textual aliases such as ``"deflated"`` to their integer
# counterparts. ``zipfile.ZipFile`` itself only accepts the integer
# ``ZIP_*`` constants, so the textual form is a convenience added here.
if isinstance(comp, str):
_str_to_const = {
"stored": zipfile.ZIP_STORED,
"deflated": zipfile.ZIP_DEFLATED,
}
if hasattr(zipfile, "ZIP_BZIP2"):
_str_to_const["bzip2"] = zipfile.ZIP_BZIP2
if hasattr(zipfile, "ZIP_LZMA"):
_str_to_const["lzma"] = zipfile.ZIP_LZMA

key = comp.lower()
if key not in _str_to_const:
raise ValueError(
f"Unknown compression method: {comp!r}. "
f"Valid options: {list(_str_to_const)}"
)
comp = _str_to_const[key]
elif isinstance(comp, int):
# Verify that the provided integer constant is supported by the
# runtime Python build.
_valid_ints = {zipfile.ZIP_STORED, zipfile.ZIP_DEFLATED}
if hasattr(zipfile, "ZIP_BZIP2"):
_valid_ints.add(zipfile.ZIP_BZIP2)
if hasattr(zipfile, "ZIP_LZMA"):
_valid_ints.add(zipfile.ZIP_LZMA)

if comp not in _valid_ints:
raise ValueError(
f"Unknown compression method: {comp}. "
f"Valid options: {sorted(_valid_ints)}"
)
else:
raise TypeError(
"compression must be an int (zipfile constant) or a str "
"specifying the method"
)

# Persist the (possibly normalised) compression constant back into kwargs
zipfile_kwargs["compression"] = comp

# Validate ``compresslevel`` – ignore if the user passed ``None``.
cl = zipfile_kwargs.pop("compresslevel", None)
if cl is not None:
if not isinstance(cl, int):
raise ValueError("compresslevel must be an integer or None")

if comp == zipfile.ZIP_STORED:
raise ValueError(
"compresslevel is not applicable when using ZIP_STORED."
)

def _in_range(minv: int, maxv: int) -> bool:
return minv <= cl <= maxv

if comp == zipfile.ZIP_DEFLATED and not _in_range(0, 9):
raise ValueError("For DEFLATED, compresslevel must be between 0 and 9.")
if (hasattr(zipfile, "ZIP_BZIP2") and comp == zipfile.ZIP_BZIP2
and not _in_range(1, 9)):
raise ValueError("For BZIP2, compresslevel must be between 1 and 9.")
if (hasattr(zipfile, "ZIP_LZMA") and comp == zipfile.ZIP_LZMA
and not _in_range(0, 9)):
raise ValueError("For LZMA, compresslevel must be between 0 and 9.")

# Store the validated compresslevel back into kwargs
zipfile_kwargs["compresslevel"] = cl

# Create the ZipFile object
zipf = zipfile_factory(file, mode="w", **zipfile_kwargs)

zipf = zipfile_factory(file, mode="w", compression=compression)
try:
for key, val in namedict.items():
fname = key + '.npy'
Expand Down
Loading
Loading
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy