Skip to content

Commit 8b77b6e

Browse files
committed
[WIP] ENH: dask+cupy, dask+sparse etc. namespaces
1 parent 16978e6 commit 8b77b6e

File tree

4 files changed

+80
-6
lines changed

4 files changed

+80
-6
lines changed

array_api_compat/common/_helpers.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,9 @@ def is_dask_namespace(xp: Namespace) -> bool:
352352
"""
353353
Returns True if `xp` is a Dask namespace.
354354
355-
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
355+
This includes ``dask.array`` itself, the version wrapped by array-api-compat,
356+
and the bespoke namespaces generated by
357+
``array_api_compat.dask.array.wrap_namespace``.
356358
357359
See Also
358360
--------
@@ -366,7 +368,11 @@ def is_dask_namespace(xp: Namespace) -> bool:
366368
is_pydata_sparse_namespace
367369
is_array_api_strict_namespace
368370
"""
369-
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
371+
da_compat_name = _compat_module_name() + '.dask.array'
372+
return (
373+
xp.__name__ in {'dask.array', da_compat_name}
374+
or xp.__name__.startswith(da_compat_name + '.')
375+
)
370376

371377

372378
def is_jax_namespace(xp: Namespace) -> bool:
@@ -543,8 +549,16 @@ def your_function(x, y):
543549
elif is_dask_array(x):
544550
if _use_compat:
545551
_check_api_version(api_version)
546-
from ..dask import array as dask_namespace
547-
namespaces.add(dask_namespace)
552+
from ..dask.array import wrap_namespace
553+
554+
# The meta-namespace is only used to generate the meta-array, so it
555+
# would be useless to create a namespace such as e.g.
556+
# array_api_compat.dask.array.array_api_compat.cupy.
557+
# It would get worse once you vendor array-api-compat!
558+
# So keep it clean with array_api_compat.dask.array.cupy.
559+
mxp = array_namespace(x._meta, use_compat=False)
560+
xp = wrap_namespace(mxp)
561+
namespaces.add(xp)
548562
else:
549563
import dask.array as da
550564
namespaces.add(da)

array_api_compat/dask/array/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# These imports may overwrite names from the import * above.
44
from ._aliases import * # noqa: F403
5+
from ._meta import wrap_namespace # noqa: F401
56

67
__array_api_version__ = '2024.12'
78

array_api_compat/dask/array/_aliases.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def asarray(
148148
dtype: Optional[DType] = None,
149149
device: Optional[Device] = None,
150150
copy: Optional[bool] = None,
151+
like: Optional[Array] = None,
151152
**kwargs,
152153
) -> Array:
153154
"""
@@ -164,7 +165,11 @@ def asarray(
164165
if copy is False:
165166
raise ValueError("Unable to avoid copy when changing dtype")
166167
obj = obj.astype(dtype)
167-
return obj.copy() if copy else obj
168+
if copy:
169+
obj = obj.copy()
170+
if like is not None:
171+
obj = da.asarray(obj, like=like)
172+
return obj
168173

169174
if copy is False:
170175
raise NotImplementedError(
@@ -173,7 +178,11 @@ def asarray(
173178

174179
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
175180
# see https://github.com/dask/dask/pull/11524/
176-
obj = np.array(obj, dtype=dtype, copy=True)
181+
if like is not None:
182+
mxp = array_namespace(like)
183+
obj = mxp.asarray(obj, dtype=dtype, copy=True)
184+
else:
185+
obj = np.array(obj, dtype=dtype, copy=True)
177186
return da.from_array(obj)
178187

179188

array_api_compat/dask/array/_meta.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import functools
2+
import sys
3+
import types
4+
5+
from ...common._helpers import is_numpy_namespace
6+
from ...common._typing import Namespace
7+
8+
__all__ = ['wrap_namespace']
9+
_all_ignore = ['functools', 'sys', 'types', 'is_numpy_namespace']
10+
11+
12+
def wrap_namespace(xp: Namespace) -> Namespace:
13+
"""Create a bespoke Dask namespace that wraps around another namespace.
14+
15+
Parameters
16+
----------
17+
xp : namespace
18+
Namespace to be wrapped by Dask
19+
20+
Returns
21+
-------
22+
namespace :
23+
A module object that duplicates array_api_compat.dask.array, with the
24+
difference that all creation functions will create an array with the same
25+
meta namespace as the input.
26+
"""
27+
from .. import array as da_compat
28+
29+
if is_numpy_namespace(xp):
30+
return da_compat
31+
32+
mod_name = f'{da_compat.__name__}.{xp.__name__}'
33+
try:
34+
return sys.modules[mod_name]
35+
except KeyError:
36+
pass
37+
38+
mod = types.ModuleType(mod_name)
39+
sys.modules[mod_name] = mod
40+
41+
meta = xp.empty(())
42+
for name, v in da_compat.__dict__.items():
43+
if name.startswith('_'):
44+
continue
45+
if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack',
46+
'full', 'linspace', 'ones', 'zeros'}:
47+
v = functools.wraps(v)(functools.partial(v, like=meta))
48+
setattr(mod, name, v)
49+
50+
return mod

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy