Skip to content

Commit 7ee140a

Browse files
authored
Merge branch 'main' into pre-commit-ci-update-config
2 parents 8e747e3 + 0f0563b commit 7ee140a

File tree

3 files changed

+101
-42
lines changed

3 files changed

+101
-42
lines changed

src/zarr/testing/stateful.py

Lines changed: 60 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
import builtins
2-
from typing import Any
2+
import functools
3+
from collections.abc import Callable
4+
from typing import Any, TypeVar, cast
35

46
import hypothesis.extra.numpy as npst
57
import hypothesis.strategies as st
@@ -24,15 +26,43 @@
2426
from zarr.testing.strategies import (
2527
basic_indices,
2628
chunk_paths,
29+
dimension_names,
2730
key_ranges,
2831
node_names,
2932
np_array_and_chunks,
30-
numpy_arrays,
33+
orthogonal_indices,
3134
)
3235
from zarr.testing.strategies import keys as zarr_keys
3336

3437
MAX_BINARY_SIZE = 100
3538

39+
F = TypeVar("F", bound=Callable[..., Any])
40+
41+
42+
def with_frequency(frequency: float) -> Callable[[F], F]:
43+
"""This needs to be deterministic for hypothesis replaying"""
44+
45+
def decorator(func: F) -> F:
46+
counter_attr = f"__{func.__name__}_counter"
47+
48+
@functools.wraps(func)
49+
def wrapper(*args: Any, **kwargs: Any) -> Any:
50+
return func(*args, **kwargs)
51+
52+
@precondition
53+
def frequency_check(f: Any) -> Any:
54+
if not hasattr(f, counter_attr):
55+
setattr(f, counter_attr, 0)
56+
57+
current_count = getattr(f, counter_attr) + 1
58+
setattr(f, counter_attr, current_count)
59+
60+
return (current_count * frequency) % 1.0 >= (1.0 - frequency)
61+
62+
return cast(F, frequency_check(wrapper))
63+
64+
return decorator
65+
3666

3767
def split_prefix_name(path: str) -> tuple[str, str]:
3868
split = path.rsplit("/", maxsplit=1)
@@ -90,11 +120,7 @@ def add_group(self, name: str, data: DataObject) -> None:
90120
zarr.group(store=self.store, path=path)
91121
zarr.group(store=self.model, path=path)
92122

93-
@rule(
94-
data=st.data(),
95-
name=node_names,
96-
array_and_chunks=np_array_and_chunks(arrays=numpy_arrays(zarr_formats=st.just(3))),
97-
)
123+
@rule(data=st.data(), name=node_names, array_and_chunks=np_array_and_chunks())
98124
def add_array(
99125
self,
100126
data: DataObject,
@@ -122,12 +148,17 @@ def add_array(
122148
path=path,
123149
store=store,
124150
fill_value=fill_value,
151+
zarr_format=3,
152+
dimension_names=data.draw(
153+
dimension_names(ndim=array.ndim), label="dimension names"
154+
),
125155
# Chose bytes codec to avoid wasting time compressing the data being written
126156
codecs=[BytesCodec()],
127157
)
128158
self.all_arrays.add(path)
129159

130160
@rule()
161+
@with_frequency(0.25)
131162
def clear(self) -> None:
132163
note("clearing")
133164
import zarr
@@ -192,6 +223,14 @@ def delete_chunk(self, data: DataObject) -> None:
192223
self._sync(self.model.delete(path))
193224
self._sync(self.store.delete(path))
194225

226+
@precondition(lambda self: bool(self.all_arrays))
227+
@rule(data=st.data())
228+
def check_array(self, data: DataObject) -> None:
229+
path = data.draw(st.sampled_from(sorted(self.all_arrays)))
230+
actual = zarr.open_array(self.store, path=path)[:]
231+
expected = zarr.open_array(self.model, path=path)[:]
232+
np.testing.assert_equal(actual, expected)
233+
195234
@precondition(lambda self: bool(self.all_arrays))
196235
@rule(data=st.data())
197236
def overwrite_array_basic_indexing(self, data: DataObject) -> None:
@@ -206,6 +245,20 @@ def overwrite_array_basic_indexing(self, data: DataObject) -> None:
206245
model_array[slicer] = new_data
207246
store_array[slicer] = new_data
208247

248+
@precondition(lambda self: bool(self.all_arrays))
249+
@rule(data=st.data())
250+
def overwrite_array_orthogonal_indexing(self, data: DataObject) -> None:
251+
array = data.draw(st.sampled_from(sorted(self.all_arrays)))
252+
model_array = zarr.open_array(path=array, store=self.model)
253+
store_array = zarr.open_array(path=array, store=self.store)
254+
indexer, _ = data.draw(orthogonal_indices(shape=model_array.shape))
255+
note(f"overwriting array orthogonal {indexer=}")
256+
new_data = data.draw(
257+
npst.arrays(shape=model_array.oindex[indexer].shape, dtype=model_array.dtype) # type: ignore[union-attr]
258+
)
259+
model_array.oindex[indexer] = new_data
260+
store_array.oindex[indexer] = new_data
261+
209262
@precondition(lambda self: bool(self.all_arrays))
210263
@rule(data=st.data())
211264
def resize_array(self, data: DataObject) -> None:

src/zarr/testing/strategies.py

Lines changed: 37 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,7 @@ def paths(draw: st.DrawFn, *, max_num_nodes: int | None = None) -> str:
4343
return draw(st.just("/") | keys(max_num_nodes=max_num_nodes))
4444

4545

46-
def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
46+
def dtypes() -> st.SearchStrategy[np.dtype[Any]]:
4747
return (
4848
npst.boolean_dtypes()
4949
| npst.integer_dtypes(endianness="=")
@@ -57,18 +57,12 @@ def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
5757
)
5858

5959

60+
def v3_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
61+
return dtypes()
62+
63+
6064
def v2_dtypes() -> st.SearchStrategy[np.dtype[Any]]:
61-
return (
62-
npst.boolean_dtypes()
63-
| npst.integer_dtypes(endianness="=")
64-
| npst.unsigned_integer_dtypes(endianness="=")
65-
| npst.floating_dtypes(endianness="=")
66-
| npst.complex_number_dtypes(endianness="=")
67-
| npst.byte_string_dtypes(endianness="=")
68-
| npst.unicode_string_dtypes(endianness="=")
69-
| npst.datetime64_dtypes(endianness="=")
70-
| npst.timedelta64_dtypes(endianness="=")
71-
)
65+
return dtypes()
7266

7367

7468
def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]:
@@ -144,7 +138,7 @@ def array_metadata(
144138
shape = draw(array_shapes())
145139
ndim = len(shape)
146140
chunk_shape = draw(array_shapes(min_dims=ndim, max_dims=ndim))
147-
np_dtype = draw(v3_dtypes())
141+
np_dtype = draw(dtypes())
148142
dtype = get_data_type_from_native_dtype(np_dtype)
149143
fill_value = draw(npst.from_dtype(np_dtype))
150144
if zarr_format == 2:
@@ -179,14 +173,12 @@ def numpy_arrays(
179173
*,
180174
shapes: st.SearchStrategy[tuple[int, ...]] = array_shapes,
181175
dtype: np.dtype[Any] | None = None,
182-
zarr_formats: st.SearchStrategy[ZarrFormat] = zarr_formats,
183176
) -> npt.NDArray[Any]:
184177
"""
185178
Generate numpy arrays that can be saved in the provided Zarr format.
186179
"""
187-
zarr_format = draw(zarr_formats)
188180
if dtype is None:
189-
dtype = draw(v3_dtypes() if zarr_format == 3 else v2_dtypes())
181+
dtype = draw(dtypes())
190182
if np.issubdtype(dtype, np.str_):
191183
safe_unicode_strings = safe_unicode_for_dtype(dtype)
192184
return draw(npst.arrays(dtype=dtype, shape=shapes, elements=safe_unicode_strings))
@@ -255,17 +247,24 @@ def arrays(
255247
attrs: st.SearchStrategy = attrs,
256248
zarr_formats: st.SearchStrategy = zarr_formats,
257249
) -> Array:
258-
store = draw(stores)
259-
path = draw(paths)
260-
name = draw(array_names)
261-
attributes = draw(attrs)
262-
zarr_format = draw(zarr_formats)
250+
store = draw(stores, label="store")
251+
path = draw(paths, label="array parent")
252+
name = draw(array_names, label="array name")
253+
attributes = draw(attrs, label="attributes")
254+
zarr_format = draw(zarr_formats, label="zarr format")
263255
if arrays is None:
264-
arrays = numpy_arrays(shapes=shapes, zarr_formats=st.just(zarr_format))
265-
nparray = draw(arrays)
266-
chunk_shape = draw(chunk_shapes(shape=nparray.shape))
256+
arrays = numpy_arrays(shapes=shapes)
257+
nparray = draw(arrays, label="array data")
258+
chunk_shape = draw(chunk_shapes(shape=nparray.shape), label="chunk shape")
259+
extra_kwargs = {}
267260
if zarr_format == 3 and all(c > 0 for c in chunk_shape):
268-
shard_shape = draw(st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape))
261+
shard_shape = draw(
262+
st.none() | shard_shapes(shape=nparray.shape, chunk_shape=chunk_shape),
263+
label="shard shape",
264+
)
265+
extra_kwargs["dimension_names"] = draw(
266+
dimension_names(ndim=nparray.ndim), label="dimension names"
267+
)
269268
else:
270269
shard_shape = None
271270
# test that None works too.
@@ -286,6 +285,7 @@ def arrays(
286285
attributes=attributes,
287286
# compressor=compressor, # FIXME
288287
fill_value=fill_value,
288+
**extra_kwargs,
289289
)
290290

291291
assert isinstance(a, Array)
@@ -385,13 +385,19 @@ def orthogonal_indices(
385385
npindexer = []
386386
ndim = len(shape)
387387
for axis, size in enumerate(shape):
388-
val = draw(
389-
npst.integer_array_indices(
388+
if size != 0:
389+
strategy = npst.integer_array_indices(
390390
shape=(size,), result_shape=npst.array_shapes(min_side=1, max_side=size, max_dims=1)
391-
)
392-
| basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
393-
.map(lambda x: (x,) if not isinstance(x, tuple) else x) # bare ints, slices
394-
.filter(bool) # skip empty tuple
391+
) | basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
392+
else:
393+
strategy = basic_indices(min_dims=1, shape=(size,), allow_ellipsis=False)
394+
395+
val = draw(
396+
strategy
397+
# bare ints, slices
398+
.map(lambda x: (x,) if not isinstance(x, tuple) else x)
399+
# skip empty tuple
400+
.filter(bool)
395401
)
396402
(idxr,) = val
397403
if isinstance(idxr, int):

tests/test_properties.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -76,10 +76,10 @@ def deep_equal(a: Any, b: Any) -> bool:
7676

7777

7878
@pytest.mark.filterwarnings("ignore::zarr.core.dtype.common.UnstableSpecificationWarning")
79-
@given(data=st.data(), zarr_format=zarr_formats)
80-
def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
81-
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
82-
zarray = data.draw(arrays(arrays=st.just(nparray), zarr_formats=st.just(zarr_format)))
79+
@given(data=st.data())
80+
def test_array_roundtrip(data: st.DataObject) -> None:
81+
nparray = data.draw(numpy_arrays())
82+
zarray = data.draw(arrays(arrays=st.just(nparray)))
8383
assert_array_equal(nparray, zarray[:])
8484

8585

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy