Skip to content

Commit dc44205

Browse files
authored
Merge pull request #301 from crusaderky/test_assert_equal
TST: rework tests for `xp_assert_equal`
2 parents e4ecb82 + 11b535c commit dc44205

File tree

1 file changed

+132
-139
lines changed

1 file changed

+132
-139
lines changed

tests/test_testing.py

Lines changed: 132 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections.abc import Callable
2-
from contextlib import nullcontext
32
from types import ModuleType
43
from typing import cast
54

@@ -21,160 +20,154 @@
2120
from array_api_extra._lib._utils._typing import Array, Device
2221
from array_api_extra.testing import lazy_xp_function
2322

24-
# mypy: disable-error-code=decorated-any
23+
# mypy: disable-error-code="decorated-any, explicit-any"
2524
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
2625

27-
param_assert_equal_close = pytest.mark.parametrize(
28-
"func",
29-
[
30-
xp_assert_equal,
31-
xp_assert_less,
32-
pytest.param(
33-
xp_assert_close,
34-
marks=pytest.mark.xfail_xp_backend(
35-
Backend.SPARSE, reason="no isdtype", strict=False
36-
),
37-
),
38-
],
39-
)
40-
4126

4227
def test_as_numpy_array(xp: ModuleType, device: Device):
4328
x = xp.asarray([1, 2, 3], device=device)
4429
y = as_numpy_array(x, xp=xp)
4530
assert isinstance(y, np.ndarray)
4631

4732

48-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype", strict=False)
49-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
50-
def test_assert_close_equal_basic(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
51-
func(xp.asarray(0), xp.asarray(0))
52-
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
53-
54-
with pytest.raises(AssertionError, match="shapes do not match"):
55-
func(xp.asarray([0]), xp.asarray([[0]]))
56-
57-
with pytest.raises(AssertionError, match="dtypes do not match"):
58-
func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64))
59-
60-
with pytest.raises(AssertionError):
61-
func(xp.asarray([1, 2]), xp.asarray([1, 3]))
62-
63-
with pytest.raises(AssertionError, match="hello"):
64-
func(xp.asarray([1, 2]), xp.asarray([1, 3]), err_msg="hello")
65-
66-
67-
@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
68-
@pytest.mark.skip_xp_backend(Backend.NUMPY_READONLY, reason="test other ns vs. numpy")
69-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
70-
def test_assert_close_equal_less_namespace(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
71-
with pytest.raises(AssertionError, match="namespaces do not match"):
72-
func(xp.asarray(0), np.asarray(0))
73-
with pytest.raises(TypeError, match="Unrecognized array input"):
74-
func(xp.asarray(0), 0)
75-
with pytest.raises(TypeError, match="list is not a supported array type"):
76-
func(xp.asarray([0]), [0])
77-
78-
79-
@param_assert_equal_close
80-
@pytest.mark.parametrize("check_shape", [False, True])
81-
def test_assert_close_equal_less_shape( # type: ignore[explicit-any]
82-
xp: ModuleType,
83-
func: Callable[..., None],
84-
check_shape: bool,
85-
):
86-
context = (
87-
pytest.raises(AssertionError, match="shapes do not match")
88-
if check_shape
89-
else nullcontext()
90-
)
91-
with context:
92-
# note: NaNs are handled by all 3 checks
93-
func(xp.asarray([xp.nan, xp.nan]), xp.asarray(xp.nan), check_shape=check_shape)
94-
95-
96-
@param_assert_equal_close
97-
@pytest.mark.parametrize("check_dtype", [False, True])
98-
def test_assert_close_equal_less_dtype( # type: ignore[explicit-any]
99-
xp: ModuleType,
100-
func: Callable[..., None],
101-
check_dtype: bool,
102-
):
103-
context = (
104-
pytest.raises(AssertionError, match="dtypes do not match")
105-
if check_dtype
106-
else nullcontext()
107-
)
108-
with context:
109-
func(
110-
xp.asarray(xp.nan, dtype=xp.float32),
111-
xp.asarray(xp.nan, dtype=xp.float64),
112-
check_dtype=check_dtype,
113-
)
114-
115-
116-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
117-
@pytest.mark.parametrize("check_scalar", [False, True])
118-
def test_assert_close_equal_less_scalar( # type: ignore[explicit-any]
119-
xp: ModuleType,
120-
func: Callable[..., None],
121-
check_scalar: bool,
122-
):
123-
context = (
124-
pytest.raises(AssertionError, match="array-ness does not match")
125-
if check_scalar
126-
else nullcontext()
33+
class TestAssertEqualCloseLess:
34+
pr_assert_close = pytest.param( # pyright: ignore[reportUnannotatedClassAttribute]
35+
xp_assert_close,
36+
marks=pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype"),
12737
)
128-
with context:
129-
func(np.asarray(xp.nan), np.asarray(xp.nan)[()], check_scalar=check_scalar)
13038

39+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close])
40+
def test_assert_equal_close_basic(self, xp: ModuleType, func: Callable[..., None]):
41+
func(xp.asarray(0), xp.asarray(0))
42+
func(xp.asarray([1, 2]), xp.asarray([1, 2]))
13143

132-
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
133-
def test_assert_close_tolerance(xp: ModuleType):
134-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.03)
135-
with pytest.raises(AssertionError):
136-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), rtol=0.01)
44+
with pytest.raises(AssertionError, match="Mismatched elements"):
45+
func(xp.asarray([1, 2]), xp.asarray([2, 1]))
13746

138-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=3)
139-
with pytest.raises(AssertionError):
140-
xp_assert_close(xp.asarray([100.0]), xp.asarray([102.0]), atol=1)
47+
with pytest.raises(AssertionError, match="hello"):
48+
func(xp.asarray([1, 2]), xp.asarray([2, 1]), err_msg="hello")
14149

50+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
51+
def test_shape_dtype(self, xp: ModuleType, func: Callable[..., None]):
52+
with pytest.raises(AssertionError, match="shapes do not match"):
53+
func(xp.asarray([0]), xp.asarray([[0]]))
14254

143-
def test_assert_less_basic(xp: ModuleType):
144-
xp_assert_less(xp.asarray(-1), xp.asarray(0))
145-
xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
146-
with pytest.raises(AssertionError):
147-
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
148-
with pytest.raises(AssertionError, match="hello"):
149-
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]), err_msg="hello")
55+
with pytest.raises(AssertionError, match="dtypes do not match"):
56+
func(xp.asarray(0, dtype=xp.float32), xp.asarray(0, dtype=xp.float64))
15057

151-
152-
@pytest.mark.skip_xp_backend(Backend.SPARSE, reason="index by sparse array")
153-
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
154-
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close])
155-
def test_assert_close_equal_none_shape(xp: ModuleType, func: Callable[..., None]): # type: ignore[explicit-any]
156-
"""On Dask and other lazy backends, test that a shape with NaN's or None's
157-
can be compared to a real shape.
158-
"""
159-
a = xp.asarray([1, 2])
160-
a = a[a > 1]
161-
162-
func(a, xp.asarray([2]))
163-
with pytest.raises(AssertionError):
164-
func(a, xp.asarray([2, 3]))
165-
with pytest.raises(AssertionError):
166-
func(a, xp.asarray(2))
167-
with pytest.raises(AssertionError):
168-
func(a, xp.asarray([3]))
169-
170-
# Swap actual and desired
171-
func(xp.asarray([2]), a)
172-
with pytest.raises(AssertionError):
173-
func(xp.asarray([2, 3]), a)
174-
with pytest.raises(AssertionError):
175-
func(xp.asarray(2), a)
176-
with pytest.raises(AssertionError):
177-
func(xp.asarray([3]), a)
58+
@pytest.mark.skip_xp_backend(Backend.NUMPY, reason="test other ns vs. numpy")
59+
@pytest.mark.skip_xp_backend(
60+
Backend.NUMPY_READONLY, reason="test other ns vs. numpy"
61+
)
62+
@pytest.mark.parametrize("func", [xp_assert_equal, xp_assert_close, xp_assert_less])
63+
def test_namespace(self, xp: ModuleType, func: Callable[..., None]):
64+
with pytest.raises(AssertionError, match="namespaces do not match"):
65+
func(xp.asarray(0), np.asarray(0))
66+
with pytest.raises(TypeError, match="Unrecognized array input"):
67+
func(xp.asarray(0), 0)
68+
with pytest.raises(TypeError, match="list is not a supported array type"):
69+
func(xp.asarray([0]), [0])
70+
71+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
72+
def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
73+
a = xp.asarray([1] if func is xp_assert_less else [2])
74+
b = xp.asarray(2)
75+
c = xp.asarray(0)
76+
d = xp.asarray([2, 2])
77+
78+
with pytest.raises(AssertionError, match="shapes do not match"):
79+
func(a, b)
80+
func(a, b, check_shape=False)
81+
with pytest.raises(AssertionError, match="Mismatched elements"):
82+
func(a, c, check_shape=False)
83+
with pytest.raises(AssertionError, match=r"shapes \(1,\), \(2,\) mismatch"):
84+
func(a, d, check_shape=False)
85+
86+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
87+
def test_check_dtype(self, xp: ModuleType, func: Callable[..., None]):
88+
a = xp.asarray(1 if func is xp_assert_less else 2)
89+
b = xp.asarray(2, dtype=xp.int16)
90+
c = xp.asarray(0, dtype=xp.int16)
91+
92+
with pytest.raises(AssertionError, match="dtypes do not match"):
93+
func(a, b)
94+
func(a, b, check_dtype=False)
95+
with pytest.raises(AssertionError, match="Mismatched elements"):
96+
func(a, c, check_dtype=False)
97+
98+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
99+
@pytest.mark.xfail_xp_backend(
100+
Backend.SPARSE, reason="sparse [()] returns np.generic"
101+
)
102+
def test_check_scalar(
103+
self, xp: ModuleType, library: Backend, func: Callable[..., None]
104+
):
105+
a = xp.asarray(1 if func is xp_assert_less else 2)
106+
b = xp.asarray(2)[()] # Note: only makes a difference on NumPy
107+
c = xp.asarray(0)
108+
109+
func(a, b)
110+
if library.like(Backend.NUMPY):
111+
with pytest.raises(AssertionError, match="array-ness does not match"):
112+
func(a, b, check_scalar=True)
113+
else:
114+
func(a, b, check_scalar=True)
115+
with pytest.raises(AssertionError, match="Mismatched elements"):
116+
func(a, c, check_scalar=True)
117+
118+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no isdtype")
119+
@pytest.mark.parametrize("dtype", ["int64", "float64"])
120+
def test_assert_close_tolerance(self, dtype: str, xp: ModuleType):
121+
a = xp.asarray([100], dtype=getattr(xp, dtype))
122+
b = xp.asarray([102], dtype=getattr(xp, dtype))
123+
124+
with pytest.raises(AssertionError, match="Mismatched elements"):
125+
xp_assert_close(a, b)
126+
127+
xp_assert_close(a, b, rtol=0.03)
128+
with pytest.raises(AssertionError, match="Mismatched elements"):
129+
xp_assert_close(a, b, rtol=0.01)
130+
131+
xp_assert_close(a, b, atol=3)
132+
with pytest.raises(AssertionError, match="Mismatched elements"):
133+
xp_assert_close(a, b, atol=1)
134+
135+
def test_assert_less(self, xp: ModuleType):
136+
xp_assert_less(xp.asarray(-1), xp.asarray(0))
137+
xp_assert_less(xp.asarray([1, 2]), xp.asarray([2, 3]))
138+
with pytest.raises(AssertionError, match="Mismatched elements"):
139+
xp_assert_less(xp.asarray([1, 1]), xp.asarray([2, 1]))
140+
141+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
142+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="index by sparse array")
143+
@pytest.mark.skip_xp_backend(Backend.ARRAY_API_STRICTEST, reason="boolean indexing")
144+
def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
145+
"""On Dask and other lazy backends, test that a shape with NaN's or None's
146+
can be compared to a real shape.
147+
"""
148+
# actual has shape=(None, )
149+
a = xp.asarray([1] if func is xp_assert_less else [2])
150+
a = a[a > 0]
151+
152+
func(a, xp.asarray([2]))
153+
with pytest.raises(AssertionError, match="shapes do not match"):
154+
func(a, xp.asarray(2))
155+
with pytest.raises(AssertionError, match="shapes do not match"):
156+
func(a, xp.asarray([2, 3]))
157+
with pytest.raises(AssertionError, match="Mismatched elements"):
158+
func(a, xp.asarray([0]))
159+
160+
# desired has shape=(None, )
161+
a = xp.asarray([3] if func is xp_assert_less else [2])
162+
a = a[a > 0]
163+
164+
func(xp.asarray([2]), a)
165+
with pytest.raises(AssertionError, match="shapes do not match"):
166+
func(xp.asarray(2), a)
167+
with pytest.raises(AssertionError, match="shapes do not match"):
168+
func(xp.asarray([2, 3]), a)
169+
with pytest.raises(AssertionError, match="Mismatched elements"):
170+
func(xp.asarray([4]), a)
178171

179172

180173
def good_lazy(x: Array) -> Array:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy