Skip to content

Commit 72d7635

Browse files
authored
more reduction tests (xarray-contrib#11)
* boolean reduce * order reduce * index ordering tests * cumulative reduce tests `cumprod` / `cumulative_prod` are not part of the array API * skip `cumprod` since it's most likely not implemented * fix the expected value for cumulative reductions `Variable` implements n-d cum ops by iterating over the axes. * pin `numpy<2.1` until `xarray` issues a new release * resolve the `argmin` / `argmax` warnings by opting into the new behavior * try installing the nightly version of `xarray` * print the type of `actual` if it didn't match * try the recommended syntax for `pip` installing deps * put the options on a single line * upgrade packages * install nightly `xarray` as a separate step
1 parent 3c092cb commit 72d7635

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

.github/workflows/ci.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,10 @@ jobs:
6565
python=${{matrix.python-version}}
6666
conda
6767
68+
- name: Install nightly xarray
69+
run: |
70+
python -m pip install --upgrade --pre -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple xarray
71+
6872
- name: Install xarray-array-testing
6973
run: |
7074
python -m pip install --no-deps -e .

xarray_array_testing/reduction.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from contextlib import nullcontext
22

33
import hypothesis.strategies as st
4+
import numpy as np
45
import pytest
56
import xarray.testing.strategies as xrst
67
from hypothesis import given
@@ -24,4 +25,76 @@ def test_variable_numerical_reduce(self, op, data):
2425
# compute using xp.<OP>(array)
2526
expected = getattr(self.xp, op)(variable.data)
2627

28+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
29+
self.assert_equal(actual, expected)
30+
31+
@pytest.mark.parametrize("op", ["all", "any"])
32+
@given(st.data())
33+
def test_variable_boolean_reduce(self, op, data):
34+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
35+
36+
with self.expected_errors(op, variable=variable):
37+
# compute using xr.Variable.<OP>()
38+
actual = getattr(variable, op)().data
39+
# compute using xp.<OP>(array)
40+
expected = getattr(self.xp, op)(variable.data)
41+
42+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
43+
self.assert_equal(actual, expected)
44+
45+
@pytest.mark.parametrize("op", ["max", "min"])
46+
@given(st.data())
47+
def test_variable_order_reduce(self, op, data):
48+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
49+
50+
with self.expected_errors(op, variable=variable):
51+
# compute using xr.Variable.<OP>()
52+
actual = getattr(variable, op)().data
53+
# compute using xp.<OP>(array)
54+
expected = getattr(self.xp, op)(variable.data)
55+
56+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
57+
self.assert_equal(actual, expected)
58+
59+
@pytest.mark.parametrize("op", ["argmax", "argmin"])
60+
@given(st.data())
61+
def test_variable_order_reduce_index(self, op, data):
62+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
63+
64+
with self.expected_errors(op, variable=variable):
65+
# compute using xr.Variable.<OP>()
66+
actual = {k: v.item() for k, v in getattr(variable, op)(dim=...).items()}
67+
68+
# compute using xp.<OP>(array)
69+
index = getattr(self.xp, op)(variable.data)
70+
unraveled = np.unravel_index(index, variable.shape)
71+
expected = dict(zip(variable.dims, unraveled))
72+
73+
self.assert_equal(actual, expected)
74+
75+
@pytest.mark.parametrize(
76+
"op",
77+
[
78+
"cumsum",
79+
pytest.param(
80+
"cumprod",
81+
marks=pytest.mark.skip(reason="not yet included in the array api"),
82+
),
83+
],
84+
)
85+
@given(st.data())
86+
def test_variable_cumulative_reduce(self, op, data):
87+
array_api_names = {"cumsum": "cumulative_sum", "cumprod": "cumulative_prod"}
88+
variable = data.draw(xrst.variables(array_strategy_fn=self.array_strategy_fn))
89+
90+
with self.expected_errors(op, variable=variable):
91+
# compute using xr.Variable.<OP>()
92+
actual = getattr(variable, op)().data
93+
# compute using xp.<OP>(array)
94+
# Variable implements n-d cumulative ops by iterating over dims
95+
expected = variable.data
96+
for axis in range(variable.ndim):
97+
expected = getattr(self.xp, array_api_names[op])(expected, axis=axis)
98+
99+
assert isinstance(actual, self.array_type), f"wrong type: {type(actual)}"
27100
self.assert_equal(actual, expected)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy