Skip to content

Commit 6bea970

Browse files
authored
Add function Eye to the Array API (#29)
* Add function Eye to the Array API * remove eye * improve * fix overflow
1 parent 35cb298 commit 6bea970

File tree

10 files changed

+157
-10
lines changed

10 files changed

+157
-10
lines changed

_unittests/onnx-numpy-skips.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
# API failures
22
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
3-
array_api_tests/test_creation_functions.py::test_asarray_scalars
4-
array_api_tests/test_creation_functions.py::test_arange
3+
# uses __setitem__
54
array_api_tests/test_creation_functions.py::test_asarray_arrays
65
array_api_tests/test_creation_functions.py::test_empty
76
array_api_tests/test_creation_functions.py::test_empty_like
8-
array_api_tests/test_creation_functions.py::test_eye
97
array_api_tests/test_creation_functions.py::test_linspace
108
array_api_tests/test_creation_functions.py::test_meshgrid

_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like || exit 1
2+
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_eye || exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1

_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def sh(x):
3939

4040
class TestHypothesisArraysApis(ExtTestCase):
4141
MAX_ARRAY_SIZE = 10000
42+
SQRT_MAX_ARRAY_SIZE = int(10000**0.5)
4243
VERSION = "2021.12"
4344

4445
@classmethod
@@ -138,9 +139,80 @@ def fctonx(x, kw):
138139
fctonx()
139140
self.assertEqual(len(args_onxp), len(args_np))
140141

142+
def test_square_sizes_strategies(self):
143+
dtypes = dict(
144+
integer_dtypes=self.xps.integer_dtypes(),
145+
uinteger_dtypes=self.xps.unsigned_integer_dtypes(),
146+
floating_dtypes=self.xps.floating_dtypes(),
147+
numeric_dtypes=self.xps.numeric_dtypes(),
148+
boolean_dtypes=self.xps.boolean_dtypes(),
149+
scalar_dtypes=self.xps.scalar_dtypes(),
150+
)
151+
152+
dtypes_onnx = dict(
153+
integer_dtypes=self.onxps.integer_dtypes(),
154+
uinteger_dtypes=self.onxps.unsigned_integer_dtypes(),
155+
floating_dtypes=self.onxps.floating_dtypes(),
156+
numeric_dtypes=self.onxps.numeric_dtypes(),
157+
boolean_dtypes=self.onxps.boolean_dtypes(),
158+
scalar_dtypes=self.onxps.scalar_dtypes(),
159+
)
160+
161+
for k, vnp in dtypes.items():
162+
vonxp = dtypes_onnx[k]
163+
anp = self.xps.arrays(dtype=vnp, shape=shapes(self.xps))
164+
aonxp = self.onxps.arrays(dtype=vonxp, shape=shapes(self.onxps))
165+
self.assertNotEmpty(anp)
166+
self.assertNotEmpty(aonxp)
167+
168+
args_np = []
169+
170+
kws = array_api_kwargs(k=strategies.integers(), dtype=self.xps.numeric_dtypes())
171+
sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE)
172+
ncs = strategies.none() | sqrt_sizes
173+
174+
@given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws)
175+
def fctnp(n_rows, n_cols, kw):
176+
base = np.asarray(0)
177+
e = np.eye(n_rows, n_cols)
178+
self.assertNotEmpty(e.dtype)
179+
self.assertIsInstance(e, base.__class__)
180+
e = np.eye(n_rows, n_cols, **kw)
181+
self.assertNotEmpty(e.dtype)
182+
self.assertIsInstance(e, base.__class__)
183+
args_np.append((n_rows, n_cols, kw))
184+
185+
fctnp()
186+
self.assertEqual(len(args_np), 100)
187+
188+
args_onxp = []
189+
190+
kws = array_api_kwargs(
191+
k=strategies.integers(), dtype=self.onxps.numeric_dtypes()
192+
)
193+
sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE)
194+
ncs = strategies.none() | sqrt_sizes
195+
196+
@given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws)
197+
def fctonx(n_rows, n_cols, kw):
198+
base = onxp.asarray(0)
199+
e = onxp.eye(n_rows, n_cols)
200+
self.assertIsInstance(e, base.__class__)
201+
self.assertNotEmpty(e.dtype)
202+
e = onxp.eye(n_rows, n_cols, **kw)
203+
self.assertNotEmpty(e.dtype)
204+
self.assertIsInstance(e, base.__class__)
205+
args_onxp.append((n_rows, n_cols, kw))
206+
207+
fctonx()
208+
self.assertEqual(len(args_onxp), len(args_np))
209+
141210

142211
if __name__ == "__main__":
143212
# cl = TestHypothesisArraysApis()
144213
# cl.setUpClass()
145214
# cl.test_scalar_strategies()
215+
# import logging
216+
217+
# logging.basicConfig(level=logging.DEBUG)
146218
unittest.main(verbosity=2)

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,28 @@ def test_as_array(self):
142142
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
143143
self.assertEqual(r.numpy(), 9223372036854775809)
144144

145+
def test_eye(self):
146+
nr, nc = xp.asarray(4), xp.asarray(4)
147+
expected = np.eye(nr.numpy(), nc.numpy())
148+
got = xp.eye(nr, nc)
149+
self.assertEqualArray(expected, got.numpy())
150+
151+
def test_eye_nosquare(self):
152+
nr, nc = xp.asarray(4), xp.asarray(5)
153+
expected = np.eye(nr.numpy(), nc.numpy())
154+
got = xp.eye(nr, nc)
155+
self.assertEqualArray(expected, got.numpy())
156+
157+
def test_eye_k(self):
158+
nr = xp.asarray(4)
159+
expected = np.eye(nr.numpy(), k=1)
160+
got = xp.eye(nr, k=1)
161+
self.assertEqualArray(expected, got.numpy())
162+
145163

146164
if __name__ == "__main__":
147165
# import logging
148166

149167
# logging.basicConfig(level=logging.DEBUG)
150-
# TestOnnxNumpy().test_as_array()
168+
TestOnnxNumpy().test_eye()
151169
unittest.main(verbosity=2)

onnx_array_api/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"astype",
1818
"empty",
1919
"equal",
20+
"eye",
2021
"full",
2122
"full_like",
2223
"isdtype",

onnx_array_api/array_api/_onnx_common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Optional
22
import warnings
33
import numpy as np
4+
from onnx import TensorProto
45

56
with warnings.catch_warnings():
67
warnings.simplefilter("ignore")
@@ -19,6 +20,8 @@
1920
from ..npx.npx_functions import (
2021
abs as generic_abs,
2122
arange as generic_arange,
23+
copy as copy_inline,
24+
eye as generic_eye,
2225
full as generic_full,
2326
full_like as generic_full_like,
2427
ones as generic_ones,
@@ -185,6 +188,24 @@ def full(
185188
return generic_full(shape, fill_value=value, dtype=dtype, order=order)
186189

187190

191+
def eye(
192+
TEagerTensor: type,
193+
n_rows: TensorType[ElemType.int64, "I"],
194+
n_cols: OptTensorType[ElemType.int64, "I"] = None,
195+
/,
196+
*,
197+
k: ParType[int] = 0,
198+
dtype: ParType[DType] = DType(TensorProto.DOUBLE),
199+
):
200+
if isinstance(n_rows, int):
201+
n_rows = TEagerTensor(np.array(n_rows, dtype=np.int64))
202+
if n_cols is None:
203+
n_cols = n_rows
204+
elif isinstance(n_cols, int):
205+
n_cols = TEagerTensor(np.array(n_cols, dtype=np.int64))
206+
return generic_eye(n_rows, n_cols, k=k, dtype=dtype)
207+
208+
188209
def full_like(
189210
TEagerTensor: type,
190211
x: TensorType[ElemType.allowed, "T"],

onnx_array_api/npx/npx_functions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,30 @@ def expit(
473473
return var(x, op="Sigmoid")
474474

475475

476+
@npxapi_inline
477+
def eye(
478+
n_rows: TensorType[ElemType.int64, "I"],
479+
n_cols: TensorType[ElemType.int64, "I"],
480+
/,
481+
*,
482+
k: ParType[int] = 0,
483+
dtype: ParType[DType] = DType(TensorProto.DOUBLE),
484+
):
485+
"See :func:`numpy.eye`."
486+
shape = cst(np.array([-1], dtype=np.int64))
487+
shape = var(
488+
var(n_rows, shape, op="Reshape"),
489+
var(n_cols, shape, op="Reshape"),
490+
axis=0,
491+
op="Concat",
492+
)
493+
zero = zeros(shape, dtype=dtype)
494+
res = var(zero, k=k, op="EyeLike")
495+
if dtype is not None:
496+
return var(res, to=dtype.code, op="Cast")
497+
return res
498+
499+
476500
@npxapi_inline
477501
def full(
478502
shape: TensorType[ElemType.int64, "I", (None,)],

onnx_array_api/npx/npx_graph_builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ def make_node(
230230
new_kwargs[k] = v.value
231231
elif isinstance(v, DType):
232232
new_kwargs[k] = v.code
233+
elif isinstance(v, int):
234+
try:
235+
new_kwargs[k] = int(np.array(v, dtype=np.int64))
236+
except OverflowError:
237+
new_kwargs[k] = int(np.iinfo(np.int64).max)
233238
else:
234239
new_kwargs[k] = v
235240

@@ -246,6 +251,11 @@ def make_node(
246251
f"Unable to create node {op!r}, with inputs={inputs}, "
247252
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
248253
) from e
254+
except ValueError as e:
255+
raise ValueError(
256+
f"Unable to create node {op!r}, with inputs={inputs}, "
257+
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
258+
) from e
249259
for p in protos:
250260
node.attribute.append(p)
251261
if attribute_protos is not None:

onnx_array_api/npx/npx_jit_eager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,18 @@ def jit_call(self, *values, **kwargs):
510510
from ..plotting.text_plot import onnx_simple_text_plot
511511

512512
text = onnx_simple_text_plot(self.onxs[key])
513+
514+
def catch_len(x):
515+
try:
516+
return len(x)
517+
except TypeError:
518+
return 0
519+
513520
raise RuntimeError(
514521
f"Unable to run function for key={key!r}, "
515522
f"types={[type(x) for x in values]}, "
516523
f"dtypes={[getattr(x, 'dtype', type(x)) for x in values]}, "
517-
f"shapes={[getattr(x, 'shape', len(x)) for x in values]}, "
524+
f"shapes={[getattr(x, 'shape', catch_len(x)) for x in values]}, "
518525
f"kwargs={kwargs}, "
519526
f"self.input_to_kwargs_={self.input_to_kwargs_}, "
520527
f"f={self.f} from module {self.f.__module__!r} "

onnx_array_api/reference/evaluator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
from .ops.op_cast_like import CastLike_15, CastLike_19
88
from .ops.op_constant_of_shape import ConstantOfShape
99

10-
import onnx
11-
12-
print(onnx.__file__)
13-
1410

1511
logger = getLogger("onnx-array-api-eval")
1612

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy