Skip to content

Commit 40603a9

Browse files
authored
Merge pull request #95 from asmeurer/revert-all-changes2
Revert __all__ related changes from #82
2 parents ab74e4a + a73388d commit 40603a9

24 files changed

+450
-1083
lines changed

.github/workflows/ruff.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ jobs:
1616
pip install ruff
1717
# Update output format to enable automatic inline annotations.
1818
- name: Run Ruff
19-
run: ruff check --output-format=github --select F822,PLC0414,RUF022 --preview .
19+
run: ruff check --output-format=github .

array_api_compat/_internal.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from functools import wraps
66
from inspect import signature
77

8-
98
def get_xp(xp):
109
"""
1110
Decorator to automatically replace xp with the corresponding array module.
@@ -45,31 +44,3 @@ def wrapped_f(*args, **kwargs):
4544
return wrapped_f
4645

4746
return inner
48-
49-
50-
def _get_all_public_members(module, exclude=None, extend_all=False):
51-
"""Get all public members of a module.
52-
53-
Parameters
54-
----------
55-
module : module
56-
The module to get members from.
57-
exclude : callable, optional
58-
A callable that takes a name and returns True if the name should be
59-
excluded from the list of members.
60-
extend_all : bool, optional
61-
If True, extend the module's __all__ attribute with the members of the
62-
module derived from dir(module). To be used for libraries that do not have a complete __all__ list.
63-
"""
64-
members = getattr(module, "__all__", [])
65-
66-
if members and not extend_all:
67-
return members
68-
69-
if exclude is None:
70-
exclude = lambda name: name.startswith("_") # noqa: E731
71-
72-
members = members + [_ for _ in dir(module) if not exclude(_)]
73-
74-
# remove duplicates
75-
return list(set(members))

array_api_compat/common/__init__.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1 @@
1-
from ._helpers import (
2-
array_namespace,
3-
device,
4-
get_namespace,
5-
is_array_api_obj,
6-
is_cupy_array,
7-
is_dask_array,
8-
is_jax_array,
9-
is_numpy_array,
10-
is_torch_array,
11-
size,
12-
to_device,
13-
)
14-
15-
__all__ = [
16-
"array_namespace",
17-
"device",
18-
"get_namespace",
19-
"is_array_api_obj",
20-
"is_cupy_array",
21-
"is_dask_array",
22-
"is_jax_array",
23-
"is_numpy_array",
24-
"is_torch_array",
25-
"size",
26-
"to_device",
27-
]
1+
from ._helpers import * # noqa: F403

array_api_compat/common/_aliases.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,9 @@ def zeros_like(
146146

147147
# The functions here return namedtuples (np.unique() returns a normal
148148
# tuple).
149+
150+
# Note that these named tuples aren't actually part of the standard namespace,
151+
# but I don't see any issue with exporting the names here regardless.
149152
class UniqueAllResult(NamedTuple):
150153
values: ndarray
151154
indices: ndarray
@@ -545,3 +548,11 @@ def isdtype(
545548
# more strict here to match the type annotation? Note that the
546549
# array_api_strict implementation will be very strict.
547550
return dtype == kind
551+
552+
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
553+
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
554+
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
555+
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
556+
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
557+
'sort', 'nonzero', 'sum', 'prod', 'ceil', 'floor', 'trunc',
558+
'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/common/_helpers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,19 @@ def size(x):
288288
if None in x.shape:
289289
return None
290290
return math.prod(x.shape)
291+
292+
__all__ = [
293+
"array_namespace",
294+
"device",
295+
"get_namespace",
296+
"is_array_api_obj",
297+
"is_cupy_array",
298+
"is_dask_array",
299+
"is_jax_array",
300+
"is_numpy_array",
301+
"is_torch_array",
302+
"size",
303+
"to_device",
304+
]
305+
306+
_all_ignore = ['sys', 'math', 'inspect']

array_api_compat/common/_linalg.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
else:
1212
from numpy.core.numeric import normalize_axis_tuple
1313

14-
from ._aliases import matrix_transpose, isdtype
14+
from ._aliases import matmul, matrix_transpose, tensordot, vecdot, isdtype
1515
from .._internal import get_xp
1616

1717
# These are in the main NumPy namespace but not in numpy.linalg
@@ -149,4 +149,10 @@ def trace(x: ndarray, /, xp, *, offset: int = 0, dtype=None, **kwargs) -> ndarra
149149
dtype = xp.float64
150150
elif x.dtype == xp.complex64:
151151
dtype = xp.complex128
152-
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
152+
return xp.asarray(xp.trace(x, offset=offset, dtype=dtype, axis1=-2, axis2=-1, **kwargs))
153+
154+
__all__ = ['cross', 'matmul', 'outer', 'tensordot', 'EighResult',
155+
'QRResult', 'SlogdetResult', 'SVDResult', 'eigh', 'qr', 'slogdet',
156+
'svd', 'cholesky', 'matrix_rank', 'pinv', 'matrix_norm',
157+
'matrix_transpose', 'svdvals', 'vecdot', 'vector_norm', 'diagonal',
158+
'trace']

array_api_compat/common/_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ def __len__(self, /) -> int: ...
2020
SupportsBufferProtocol = Any
2121

2222
Array = Any
23-
Device = Any
23+
Device = Any

array_api_compat/cupy/__init__.py

Lines changed: 7 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,153 +1,14 @@
1-
import cupy as _cp
2-
from cupy import * # noqa: F401, F403
1+
from cupy import * # noqa: F403
32

43
# from cupy import * doesn't overwrite these builtin names
5-
from cupy import abs, max, min, round
6-
7-
from .._internal import _get_all_public_members
8-
from ..common._helpers import (
9-
array_namespace,
10-
device,
11-
get_namespace,
12-
is_array_api_obj,
13-
size,
14-
to_device,
15-
)
4+
from cupy import abs, max, min, round # noqa: F401
165

176
# These imports may overwrite names from the import * above.
18-
from ._aliases import (
19-
UniqueAllResult,
20-
UniqueCountsResult,
21-
UniqueInverseResult,
22-
acos,
23-
acosh,
24-
arange,
25-
argsort,
26-
asarray,
27-
asarray_cupy,
28-
asin,
29-
asinh,
30-
astype,
31-
atan,
32-
atan2,
33-
atanh,
34-
bitwise_invert,
35-
bitwise_left_shift,
36-
bitwise_right_shift,
37-
bool,
38-
ceil,
39-
concat,
40-
empty,
41-
empty_like,
42-
eye,
43-
floor,
44-
full,
45-
full_like,
46-
isdtype,
47-
linspace,
48-
matmul,
49-
matrix_transpose,
50-
nonzero,
51-
ones,
52-
ones_like,
53-
permute_dims,
54-
pow,
55-
prod,
56-
reshape,
57-
sort,
58-
std,
59-
sum,
60-
tensordot,
61-
trunc,
62-
unique_all,
63-
unique_counts,
64-
unique_inverse,
65-
unique_values,
66-
var,
67-
vecdot,
68-
zeros,
69-
zeros_like,
70-
)
71-
72-
__all__ = []
73-
74-
__all__ += _get_all_public_members(_cp)
75-
76-
__all__ += [
77-
"abs",
78-
"max",
79-
"min",
80-
"round",
81-
]
82-
83-
__all__ += [
84-
"array_namespace",
85-
"device",
86-
"get_namespace",
87-
"is_array_api_obj",
88-
"size",
89-
"to_device",
90-
]
91-
92-
__all__ += [
93-
"UniqueAllResult",
94-
"UniqueCountsResult",
95-
"UniqueInverseResult",
96-
"acos",
97-
"acosh",
98-
"arange",
99-
"argsort",
100-
"asarray",
101-
"asarray_cupy",
102-
"asin",
103-
"asinh",
104-
"astype",
105-
"atan",
106-
"atan2",
107-
"atanh",
108-
"bitwise_invert",
109-
"bitwise_left_shift",
110-
"bitwise_right_shift",
111-
"bool",
112-
"ceil",
113-
"concat",
114-
"empty",
115-
"empty_like",
116-
"eye",
117-
"floor",
118-
"full",
119-
"full_like",
120-
"isdtype",
121-
"linspace",
122-
"matmul",
123-
"matrix_transpose",
124-
"nonzero",
125-
"ones",
126-
"ones_like",
127-
"permute_dims",
128-
"pow",
129-
"prod",
130-
"reshape",
131-
"sort",
132-
"std",
133-
"sum",
134-
"tensordot",
135-
"trunc",
136-
"unique_all",
137-
"unique_counts",
138-
"unique_inverse",
139-
"unique_values",
140-
"var",
141-
"zeros",
142-
"zeros_like",
143-
]
144-
145-
__all__ += [
146-
"matrix_transpose",
147-
"vecdot",
148-
]
7+
from ._aliases import * # noqa: F403
1498

1509
# See the comment in the numpy __init__.py
151-
__import__(__package__ + ".linalg")
10+
__import__(__package__ + '.linalg')
11+
12+
from ..common._helpers import * # noqa: F401,F403
15213

153-
__array_api_version__ = "2022.12"
14+
__array_api_version__ = '2022.12'

array_api_compat/cupy/_aliases.py

Lines changed: 5 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
import cupy as cp
66

77
from ..common import _aliases
8-
from ..common import _linalg
9-
108
from .._internal import get_xp
119

1210
asarray = asarray_cupy = partial(_aliases._asarray, namespace='cupy')
1311
asarray.__doc__ = _aliases._asarray.__doc__
12+
del partial
1413

1514
bool = cp.bool_
1615

@@ -74,28 +73,7 @@
7473
else:
7574
isdtype = get_xp(cp)(_aliases.isdtype)
7675

77-
78-
cross = get_xp(cp)(_linalg.cross)
79-
outer = get_xp(cp)(_linalg.outer)
80-
EighResult = _linalg.EighResult
81-
QRResult = _linalg.QRResult
82-
SlogdetResult = _linalg.SlogdetResult
83-
SVDResult = _linalg.SVDResult
84-
eigh = get_xp(cp)(_linalg.eigh)
85-
qr = get_xp(cp)(_linalg.qr)
86-
slogdet = get_xp(cp)(_linalg.slogdet)
87-
svd = get_xp(cp)(_linalg.svd)
88-
cholesky = get_xp(cp)(_linalg.cholesky)
89-
matrix_rank = get_xp(cp)(_linalg.matrix_rank)
90-
pinv = get_xp(cp)(_linalg.pinv)
91-
matrix_norm = get_xp(cp)(_linalg.matrix_norm)
92-
svdvals = get_xp(cp)(_linalg.svdvals)
93-
diagonal = get_xp(cp)(_linalg.diagonal)
94-
trace = get_xp(cp)(_linalg.trace)
95-
96-
# These functions are completely new here. If the library already has them
97-
# (i.e., numpy 2.0), use the library version instead of our wrapper.
98-
if hasattr(cp.linalg, 'vector_norm'):
99-
vector_norm = cp.linalg.vector_norm
100-
else:
101-
vector_norm = get_xp(cp)(_linalg.vector_norm)
76+
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
77+
'acosh', 'asin', 'asinh', 'atan', 'atan2',
78+
'atanh', 'bitwise_left_shift', 'bitwise_invert',
79+
'bitwise_right_shift', 'concat', 'pow']

array_api_compat/cupy/_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from __future__ import annotations
22

33
__all__ = [
4+
"ndarray",
45
"Device",
56
"Dtype",
6-
"ndarray",
77
]
88

99
import sys

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy