Skip to content

Commit 22a248e

Browse files
committed
Update test_cross to test broadcastable shapes
This also updates it to only test axes from [min(x1.ndim, x2.ndim), -1], as per data-apis/array-api#740
1 parent 4403061 commit 22a248e

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

array_api_tests/test_linalg.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -148,23 +148,29 @@ def cross_args(draw, dtype_objects=dh.real_dtypes):
148148
in the drawn axis.
149149
150150
"""
151-
shape = list(draw(shapes()))
152-
size = len(shape)
153-
assume(size > 0)
151+
shape1, shape2 = draw(two_mutually_broadcastable_shapes)
152+
min_ndim = min(len(shape1), len(shape2))
153+
assume(min_ndim > 0)
154154

155-
kw = draw(kwargs(axis=integers(-size, size-1)))
155+
kw = draw(kwargs(axis=integers(-min_ndim, -1)))
156156
axis = kw.get('axis', -1)
157-
shape[axis] = 3
158-
shape = tuple(shape)
157+
if draw(booleans()):
158+
# Sometimes allow invalid inputs to test it errors
159+
shape1 = list(shape1)
160+
shape1[axis] = 3
161+
shape1 = tuple(shape1)
162+
shape2 = list(shape2)
163+
shape2[axis] = 3
164+
shape2 = tuple(shape2)
159165

160166
mutual_dtypes = shared(mutually_promotable_dtypes(dtypes=dtype_objects))
161167
arrays1 = arrays(
162168
dtype=mutual_dtypes.map(lambda pair: pair[0]),
163-
shape=shape,
169+
shape=shape1,
164170
)
165171
arrays2 = arrays(
166172
dtype=mutual_dtypes.map(lambda pair: pair[1]),
167-
shape=shape,
173+
shape=shape2,
168174
)
169175
return draw(arrays1), draw(arrays2), kw
170176

@@ -176,15 +182,17 @@ def test_cross(x1_x2_kw):
176182
x1, x2, kw = x1_x2_kw
177183

178184
axis = kw.get('axis', -1)
179-
err = "test_cross produced invalid input. This indicates a bug in the test suite."
180-
assert x1.shape == x2.shape, err
181-
shape = x1.shape
182-
assert x1.shape[axis] == x2.shape[axis] == 3, err
185+
if not (x1.shape[axis] == x2.shape[axis] == 3):
186+
ph.raises(Exception, lambda: xp.cross(x1, x2, **kw),
187+
"cross did not raise an exception for invalid shapes")
188+
return
183189

184190
res = linalg.cross(x1, x2, **kw)
185191

192+
broadcasted_shape = sh.broadcast_shapes(x1.shape, x2.shape)
193+
186194
assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype"
187-
assert res.shape == shape, "cross() did not return the correct shape"
195+
assert res.shape == broadcasted_shape, "cross() did not return the correct shape"
188196

189197
def exact_cross(a, b):
190198
assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy