Skip to content

Commit 27fbd9c

Browse files
authored
ENH: support PyTorch device='meta' (#300)
1 parent e2762f5 commit 27fbd9c

File tree

8 files changed

+111
-29
lines changed

8 files changed

+111
-29
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _apply_where( # type: ignore[explicit-any] # numpydoc ignore=PR01,RT01
153153
) -> Array:
154154
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
155155

156-
if not capabilities(xp)["boolean indexing"]:
156+
if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
157157
# jax.jit does not support assignment by boolean mask
158158
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
159159

@@ -716,7 +716,7 @@ def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
716716
# 2. backend has unique_counts and it returns a None-sized array;
717717
# e.g. Dask, ndonnx
718718
# 3. backend does not have unique_counts; e.g. wrapped JAX
719-
if capabilities(xp)["data-dependent shapes"]:
719+
if capabilities(xp, device=_compat.device(x))["data-dependent shapes"]:
720720
# xp has unique_counts; O(n) complexity
721721
_, counts = xp.unique_counts(x)
722722
n = _compat.size(counts)

src/array_api_extra/_lib/_testing.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
is_jax_namespace,
2323
is_numpy_namespace,
2424
is_pydata_sparse_namespace,
25+
is_torch_array,
2526
is_torch_namespace,
2627
to_device,
2728
)
@@ -62,18 +63,28 @@ def _check_ns_shape_dtype(
6263
msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
6364
assert actual_xp == desired_xp, msg
6465

65-
if check_shape:
66-
actual_shape = actual.shape
67-
desired_shape = desired.shape
68-
if is_dask_namespace(desired_xp):
69-
# Dask uses nan instead of None for unknown shapes
70-
if any(math.isnan(i) for i in cast(tuple[float, ...], actual_shape)):
71-
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
72-
if any(math.isnan(i) for i in cast(tuple[float, ...], desired_shape)):
73-
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
66+
# Dask uses nan instead of None for unknown shapes
67+
actual_shape = cast(tuple[float, ...], actual.shape)
68+
desired_shape = cast(tuple[float, ...], desired.shape)
69+
assert None not in actual_shape # Requires explicit support
70+
assert None not in desired_shape
71+
if is_dask_namespace(desired_xp):
72+
if any(math.isnan(i) for i in actual_shape):
73+
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
74+
if any(math.isnan(i) for i in desired_shape):
75+
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
7476

77+
if check_shape:
7578
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
7679
assert actual_shape == desired_shape, msg
80+
else:
81+
# Ignore shape, but check flattened size. This is normally done by
82+
# np.testing.assert_array_equal etc even when strict=False, but not for
83+
# non-materializable arrays.
84+
actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType]
85+
desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType]
86+
msg = f"sizes do not match: {actual_size} != f{desired_size}"
87+
assert actual_size == desired_size, msg
7788

7889
if check_dtype:
7990
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
@@ -90,6 +101,15 @@ def _check_ns_shape_dtype(
90101
return desired_xp
91102

92103

104+
def _is_materializable(x: Array) -> bool:
105+
"""
106+
Return True if you can call `as_numpy_array(x)`; False otherwise.
107+
"""
108+
# Important: here we assume that we're not tracing -
109+
# e.g. we're not inside `jax.jit`` nor `cupy.cuda.Stream.begin_capture`.
110+
return not is_torch_array(x) or x.device.type != "meta" # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
111+
112+
93113
def as_numpy_array(array: Array, *, xp: ModuleType) -> np.typing.NDArray[Any]: # type: ignore[explicit-any]
94114
"""
95115
Convert array to NumPy, bypassing GPU-CPU transfer guards and densification guards.
@@ -146,6 +166,8 @@ def xp_assert_equal(
146166
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
147167
"""
148168
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
169+
if not _is_materializable(actual):
170+
return
149171
actual_np = as_numpy_array(actual, xp=xp)
150172
desired_np = as_numpy_array(desired, xp=xp)
151173
np.testing.assert_array_equal(actual_np, desired_np, err_msg=err_msg)
@@ -181,6 +203,8 @@ def xp_assert_less(
181203
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
182204
"""
183205
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
206+
if not _is_materializable(x):
207+
return
184208
x_np = as_numpy_array(x, xp=xp)
185209
y_np = as_numpy_array(y, xp=xp)
186210
np.testing.assert_array_less(x_np, y_np, err_msg=err_msg)
@@ -229,6 +253,8 @@ def xp_assert_close(
229253
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
230254
"""
231255
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
256+
if not _is_materializable(actual):
257+
return
232258

233259
if rtol is None:
234260
if xp.isdtype(actual.dtype, ("real floating", "complex floating")):

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
is_jax_namespace,
3030
is_numpy_array,
3131
is_pydata_sparse_namespace,
32+
is_torch_namespace,
3233
)
33-
from ._typing import Array
34+
from ._typing import Array, Device
3435

3536
if TYPE_CHECKING: # pragma: no cover
3637
# TODO import from typing (requires Python >=3.12 and >=3.13)
@@ -300,7 +301,7 @@ def meta_namespace(
300301
return array_namespace(*metas)
301302

302303

303-
def capabilities(xp: ModuleType) -> dict[str, int]:
304+
def capabilities(xp: ModuleType, *, device: Device | None = None) -> dict[str, int]:
304305
"""
305306
Return patched ``xp.__array_namespace_info__().capabilities()``.
306307
@@ -311,6 +312,8 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
311312
----------
312313
xp : array_namespace
313314
The standard-compatible namespace.
315+
device : Device, optional
316+
The device to use.
314317
315318
Returns
316319
-------
@@ -326,6 +329,13 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
326329
# Fixed in jax >=0.6.0
327330
out = out.copy()
328331
out["boolean indexing"] = False
332+
if is_torch_namespace(xp):
333+
# FIXME https://github.com/data-apis/array-api/issues/945
334+
device = xp.get_default_device() if device is None else xp.device(device)
335+
if device.type == "meta": # type: ignore[union-attr] # pyright: ignore[reportAttributeAccessIssue,reportOptionalMemberAccess]
336+
out = out.copy()
337+
out["boolean indexing"] = False
338+
out["data-dependent shapes"] = False
329339
return out
330340

331341

tests/conftest.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,9 @@ def device(
211211
Where possible, return a device that is not the default one.
212212
"""
213213
if library == Backend.ARRAY_API_STRICT:
214-
d = xp.Device("device1")
215-
assert get_device(xp.empty(0)) != d
216-
return d
214+
return xp.Device("device1")
215+
if library == Backend.TORCH:
216+
return xp.device("meta")
217+
if library == Backend.TORCH_GPU:
218+
return xp.device("cpu")
217219
return get_device(xp.empty(0))

tests/test_funcs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -731,9 +731,6 @@ def test_device(self, xp: ModuleType, device: Device, equal_nan: bool):
731731
b = xp.asarray([1e-9, 1e-4, xp.nan], device=device)
732732
res = isclose(a, b, equal_nan=equal_nan)
733733
assert get_device(res) == device
734-
xp_assert_equal(
735-
isclose(a, b, equal_nan=equal_nan), xp.asarray([True, False, equal_nan])
736-
)
737734

738735

739736
class TestKron:
@@ -996,6 +993,9 @@ def test_all_python_scalars(self, assume_unique: bool):
996993
_ = setdiff1d(0, 0, assume_unique=assume_unique)
997994

998995
@assume_unique
996+
@pytest.mark.skip_xp_backend(
997+
Backend.TORCH, reason="device='meta' does not support unknown shapes"
998+
)
999999
def test_device(self, xp: ModuleType, device: Device, assume_unique: bool):
10001000
x1 = xp.asarray([3, 8, 20], device=device)
10011001
x2 = xp.asarray([2, 3, 4], device=device)

tests/test_helpers.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,31 @@ def test_xp(self, xp: ModuleType):
212212
assert meta_namespace(*args, xp=xp) in (xp, np_compat)
213213

214214

215-
def test_capabilities(xp: ModuleType):
216-
expect = {"boolean indexing", "data-dependent shapes"}
217-
if xp.__array_api_version__ >= "2024.12":
218-
expect.add("max dimensions")
219-
assert capabilities(xp).keys() == expect
215+
class TestCapabilities:
216+
def test_basic(self, xp: ModuleType):
217+
expect = {"boolean indexing", "data-dependent shapes"}
218+
if xp.__array_api_version__ >= "2024.12":
219+
expect.add("max dimensions")
220+
assert capabilities(xp).keys() == expect
221+
222+
def test_device(self, xp: ModuleType, library: Backend, device: Device):
223+
expect_keys = {"boolean indexing", "data-dependent shapes"}
224+
if xp.__array_api_version__ >= "2024.12":
225+
expect_keys.add("max dimensions")
226+
assert capabilities(xp, device=device).keys() == expect_keys
227+
228+
if library.like(Backend.TORCH):
229+
# The output of capabilities is device-specific.
230+
231+
# Test that device=None gets the current default device.
232+
expect = capabilities(xp, device=device)
233+
with xp.device(device):
234+
actual = capabilities(xp)
235+
assert actual == expect
236+
237+
# Test that we're accepting anything that is accepted by the
238+
# device= parameter in other functions
239+
actual = capabilities(xp, device=device.type) # type: ignore[attr-defined] # pyright: ignore[reportUnknownArgumentType,reportAttributeAccessIssue]
220240

221241

222242
class Wrapper(Generic[T]):

tests/test_lazy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,9 @@ def test_lazy_apply_none_shape_broadcast(xp: ModuleType):
278278
Backend.ARRAY_API_STRICT, reason="device->host copy"
279279
),
280280
pytest.mark.skip_xp_backend(Backend.CUPY, reason="device->host copy"),
281+
pytest.mark.skip_xp_backend(
282+
Backend.TORCH, reason="materialize 'meta' device"
283+
),
281284
pytest.mark.skip_xp_backend(
282285
Backend.TORCH_GPU, reason="device->host copy"
283286
),

tests/test_testing.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,17 @@
2424
# pyright: reportUnknownParameterType=false,reportMissingParameterType=false
2525

2626

27-
def test_as_numpy_array(xp: ModuleType, device: Device):
28-
x = xp.asarray([1, 2, 3], device=device)
29-
y = as_numpy_array(x, xp=xp)
30-
assert isinstance(y, np.ndarray)
27+
class TestAsNumPyArray:
28+
def test_basic(self, xp: ModuleType):
29+
x = xp.asarray([1, 2, 3])
30+
y = as_numpy_array(x, xp=xp)
31+
xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
32+
33+
@pytest.mark.skip_xp_backend(Backend.TORCH, reason="materialize 'meta' device")
34+
def test_device(self, xp: ModuleType, device: Device):
35+
x = xp.asarray([1, 2, 3], device=device)
36+
y = as_numpy_array(x, xp=xp)
37+
xp_assert_equal(y, np.asarray([1, 2, 3])) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
3138

3239

3340
class TestAssertEqualCloseLess:
@@ -80,7 +87,7 @@ def test_check_shape(self, xp: ModuleType, func: Callable[..., None]):
8087
func(a, b, check_shape=False)
8188
with pytest.raises(AssertionError, match="Mismatched elements"):
8289
func(a, c, check_shape=False)
83-
with pytest.raises(AssertionError, match=r"shapes \(1,\), \(2,\) mismatch"):
90+
with pytest.raises(AssertionError, match="sizes do not match"):
8491
func(a, d, check_shape=False)
8592

8693
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
@@ -169,6 +176,20 @@ def test_none_shape(self, xp: ModuleType, func: Callable[..., None]):
169176
with pytest.raises(AssertionError, match="Mismatched elements"):
170177
func(xp.asarray([4]), a)
171178

179+
@pytest.mark.parametrize("func", [xp_assert_equal, pr_assert_close, xp_assert_less])
180+
def test_device(self, xp: ModuleType, device: Device, func: Callable[..., None]):
181+
a = xp.asarray([1] if func is xp_assert_less else [2], device=device)
182+
b = xp.asarray([2], device=device)
183+
c = xp.asarray([2, 2], device=device)
184+
185+
func(a, b)
186+
with pytest.raises(AssertionError, match="shapes do not match"):
187+
func(a, c)
188+
# This is normally performed by np.testing.assert_array_equal etc.
189+
# but in case of torch device='meta' we have to do it manually
190+
with pytest.raises(AssertionError, match="sizes do not match"):
191+
func(a, c, check_shape=False)
192+
172193

173194
def good_lazy(x: Array) -> Array:
174195
"""A function that behaves well in Dask and jax.jit"""

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy