|
7 | 7 | import threading
|
8 | 8 | import time
|
9 | 9 | import warnings
|
| 10 | +import zipfile |
10 | 11 | from ctypes import c_bool
|
11 | 12 | from datetime import datetime
|
12 | 13 | from io import BytesIO, StringIO
|
@@ -217,6 +218,22 @@ def roundtrip(self, *args, **kwargs):
|
217 | 218 | self.arr_reloaded.fid.close()
|
218 | 219 | os.remove(self.arr_reloaded.fid.name)
|
219 | 220 |
|
| 221 | + def test_load_non_npy(self): |
| 222 | + """Test loading non-.npy files and name mapping in .npz.""" |
| 223 | + with temppath(prefix="numpy_test_npz_load_non_npy_", suffix=".npz") as tmp: |
| 224 | + with zipfile.ZipFile(tmp, "w") as npz: |
| 225 | + with npz.open("test1.npy", "w") as out_file: |
| 226 | + np.save(out_file, np.arange(10)) |
| 227 | + with npz.open("test2", "w") as out_file: |
| 228 | + np.save(out_file, np.arange(10)) |
| 229 | + with npz.open("metadata", "w") as out_file: |
| 230 | + out_file.write(b"Name: Test") |
| 231 | + with np.load(tmp) as npz: |
| 232 | + assert len(npz["test1"]) == 10 |
| 233 | + assert len(npz["test1.npy"]) == 10 |
| 234 | + assert len(npz["test2"]) == 10 |
| 235 | + assert npz["metadata"] == b"Name: Test" |
| 236 | + |
220 | 237 | @pytest.mark.skipif(IS_PYPY, reason="Hangs on PyPy")
|
221 | 238 | @pytest.mark.skipif(not IS_64BIT, reason="Needs 64bit platform")
|
222 | 239 | @pytest.mark.slow
|
|
0 commit comments