Skip to content

Commit 7f39b03

Browse files
committed
[WIP] ENH: dask+cupy, dask+sparse etc. namespaces
1 parent e14754b commit 7f39b03

File tree

4 files changed

+70
-5
lines changed

4 files changed

+70
-5
lines changed

array_api_compat/common/_helpers.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,11 @@ def is_dask_namespace(xp) -> bool:
368368
is_pydata_sparse_namespace
369369
is_array_api_strict_namespace
370370
"""
371-
return xp.__name__ in {'dask.array', _compat_module_name() + '.dask.array'}
371+
names = {'dask.array', _compat_module_name() + '.dask.array'}
372+
return (
373+
xp.__name__ in names
374+
or any(xp.__name__.startswith(name + '.') for name in names)
375+
)
372376

373377

374378
def is_jax_namespace(xp) -> bool:
@@ -541,8 +545,10 @@ def your_function(x, y):
541545
elif is_dask_array(x):
542546
if _use_compat:
543547
_check_api_version(api_version)
544-
from ..dask import array as dask_namespace
545-
namespaces.add(dask_namespace)
548+
from ..dask.array import wrap_namespace
549+
mxp = array_namespace(x._meta, use_compat=False)
550+
xp = wrap_namespace(mxp)
551+
namespaces.add(xp)
546552
else:
547553
import dask.array as da
548554
namespaces.add(da)

array_api_compat/dask/array/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# These imports may overwrite names from the import * above.
44
from ._aliases import * # noqa: F403
5+
from ._meta import wrap_namespace # noqa: F401
56

67
__array_api_version__ = '2024.12'
78

array_api_compat/dask/array/_aliases.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def asarray(
157157
dtype: Optional[Dtype] = None,
158158
device: Optional[Device] = None,
159159
copy: Optional[Union[bool, np._CopyMode]] = None,
160+
like: Optional[Array] = None,
160161
**kwargs,
161162
) -> Array:
162163
"""
@@ -172,7 +173,11 @@ def asarray(
172173
if copy is False:
173174
raise ValueError("Unable to avoid copy when changing dtype")
174175
obj = obj.astype(dtype)
175-
return obj.copy() if copy else obj
176+
if copy:
177+
obj = obj.copy()
178+
if like is not None:
179+
obj = da.asarray(obj, like=like)
180+
return obj
176181

177182
if copy is False:
178183
raise NotImplementedError(
@@ -181,7 +186,11 @@ def asarray(
181186

182187
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
183188
# see https://github.com/dask/dask/pull/11524/
184-
obj = np.array(obj, dtype=dtype, copy=True)
189+
if like is not None:
190+
mxp = array_namespace(like)
191+
obj = mxp.asarray(obj, dtype=dtype, copy=True)
192+
else:
193+
obj = np.array(obj, dtype=dtype, copy=True)
185194
return da.from_array(obj)
186195

187196

array_api_compat/dask/array/_meta.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import functools
2+
import sys
3+
import types
4+
5+
from ...common import is_numpy_namespace
6+
7+
__all__ = ['wrap_namespace']
8+
_all_ignore = ['functools', 'sys', 'types', 'is_numpy_namespace']
9+
10+
11+
def wrap_namespace(xp):
12+
"""Create a bespoke Dask namespace that wraps around another namespace.
13+
14+
Parameters
15+
----------
16+
xp : namespace
17+
Namespace to be wrapped by Dask
18+
19+
Returns
20+
-------
21+
namespace :
22+
A module object that duplicates array_api_compat.dask.array, with the
23+
difference that all creation functions will create an array with the same
24+
meta namespace as the input.
25+
"""
26+
from .. import array as da_compat
27+
28+
if is_numpy_namespace(xp):
29+
return da_compat
30+
31+
mod_name = f'{da_compat.__name__}.{xp.__name__}'
32+
try:
33+
return sys.modules[mod_name]
34+
except KeyError:
35+
pass
36+
37+
mod = types.ModuleType(mod_name)
38+
sys.modules[mod_name] = mod
39+
40+
meta = xp.empty(())
41+
for name, v in da_compat.__dict__.items():
42+
if name.startswith('_'):
43+
continue
44+
if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack',
45+
'full', 'linspace', 'ones', 'zeros'}:
46+
v = functools.wraps(v)(functools.partial(v, like=meta))
47+
setattr(mod, name, v)
48+
49+
return mod

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy