Skip to content

Commit 4a9a060

Browse files
committed
[WIP] ENH: dask+cupy, dask+sparse etc. namespaces
1 parent ecadf5b commit 4a9a060

File tree

4 files changed

+81
-7
lines changed

4 files changed

+81
-7
lines changed

array_api_compat/common/_helpers.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,9 @@ def is_dask_namespace(xp: Namespace) -> bool:
397397
"""
398398
Returns True if `xp` is a Dask namespace.
399399
400-
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
400+
This includes ``dask.array`` itself, the version wrapped by array-api-compat,
401+
and the bespoke namespaces generated by
402+
``array_api_compat.dask.array.wrap_namespace``.
401403
402404
See Also
403405
--------
@@ -411,7 +413,13 @@ def is_dask_namespace(xp: Namespace) -> bool:
411413
is_pydata_sparse_namespace
412414
is_array_api_strict_namespace
413415
"""
414-
return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"}
416+
da_compat_name = _compat_module_name() + '.dask.array'
417+
name = xp.__name__
418+
return (
419+
name in {'dask.array', da_compat_name}
420+
or name.startswith(da_compat_name + '.')
421+
and name[len(da_compat_name) + 1:] not in ("linalg", "fft")
422+
)
415423

416424

417425
def is_jax_namespace(xp: Namespace) -> bool:
@@ -597,9 +605,16 @@ def your_function(x, y):
597605
elif is_dask_array(x):
598606
if _use_compat:
599607
_check_api_version(api_version)
600-
from ..dask import array as dask_namespace
601-
602-
namespaces.add(dask_namespace)
608+
from ..dask.array import wrap_namespace
609+
610+
# The meta-namespace is only used to generate the meta-array, so it
611+
# would be useless to create a namespace such as e.g.
612+
# array_api_compat.dask.array.array_api_compat.cupy.
613+
# It would get worse once you vendor array-api-compat!
614+
# So keep it clean with array_api_compat.dask.array.cupy.
615+
mxp = array_namespace(x._meta, use_compat=False)
616+
xp = wrap_namespace(mxp)
617+
namespaces.add(xp)
603618
else:
604619
import dask.array as da
605620

array_api_compat/dask/array/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
# These imports may overwrite names from the import * above.
66
from ._aliases import * # noqa: F403
7+
from ._meta import wrap_namespace # noqa: F401
78

89
__array_api_version__: Final = "2024.12"
910

array_api_compat/dask/array/_aliases.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def asarray(
152152
dtype: DType | None = None,
153153
device: Device | None = None,
154154
copy: py_bool | None = None,
155+
like: Array | None = None,
155156
**kwargs: object,
156157
) -> Array:
157158
"""
@@ -168,7 +169,11 @@ def asarray(
168169
if copy is False:
169170
raise ValueError("Unable to avoid copy when changing dtype")
170171
obj = obj.astype(dtype)
171-
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
172+
if copy:
173+
obj = obj.copy()
174+
if like is not None:
175+
obj = da.asarray(obj, like=like)
176+
return obj
172177

173178
if copy is False:
174179
raise ValueError(
@@ -177,7 +182,11 @@ def asarray(
177182

178183
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
179184
# see https://github.com/dask/dask/pull/11524/
180-
obj = np.array(obj, dtype=dtype, copy=True)
185+
if like is not None:
186+
mxp = array_namespace(like)
187+
obj = mxp.asarray(obj, dtype=dtype, copy=True)
188+
else:
189+
obj = np.array(obj, dtype=dtype, copy=True)
181190
return da.from_array(obj)
182191

183192

array_api_compat/dask/array/_meta.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import functools
2+
import sys
3+
import types
4+
5+
from ...common._helpers import is_numpy_namespace
6+
from ...common._typing import Namespace
7+
8+
__all__ = ['wrap_namespace']
9+
10+
11+
def wrap_namespace(xp: Namespace) -> Namespace:
12+
"""Create a bespoke Dask namespace that wraps around another namespace.
13+
14+
Parameters
15+
----------
16+
xp : namespace
17+
Namespace to be wrapped by Dask
18+
19+
Returns
20+
-------
21+
namespace :
22+
A module object that duplicates array_api_compat.dask.array, with the
23+
difference that all creation functions will create an array with the same
24+
meta namespace as the input.
25+
"""
26+
from .. import array as da_compat
27+
28+
if is_numpy_namespace(xp):
29+
return da_compat
30+
31+
mod_name = f'{da_compat.__name__}.{xp.__name__}'
32+
try:
33+
return sys.modules[mod_name]
34+
except KeyError:
35+
pass
36+
37+
mod = types.ModuleType(mod_name)
38+
sys.modules[mod_name] = mod
39+
40+
meta = xp.empty(())
41+
for name, v in da_compat.__dict__.items():
42+
if name.startswith('_'):
43+
continue
44+
if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack',
45+
'full', 'linspace', 'ones', 'zeros'}:
46+
v = functools.wraps(v)(functools.partial(v, like=meta))
47+
setattr(mod, name, v)
48+
49+
return mod

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy