Skip to content

Commit 5edccab

Browse files
authored
Extends Array API to EagerOrt (#18)
* Extends Array API to EagerOrt * fix empty shape * fix shape * fix azure * refactoring * fix command line * CI * fix CI * fix CI
1 parent 37fe094 commit 5edccab

21 files changed

+444
-93
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ _cache/*
88
dist/*
99
build/*
1010
.eggs/*
11+
.hypothesis/*
1112
*egg-info/*
1213
_doc/auto_examples/*
1314
_doc/examples/_cache/*

_unittests/onnx-numpy-skips.txt

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# API failures
2+
array_api_tests/test_creation_functions.py::test_arange
3+
array_api_tests/test_creation_functions.py::test_asarray_scalars
4+
array_api_tests/test_creation_functions.py::test_asarray_arrays
5+
array_api_tests/test_creation_functions.py::test_empty
6+
array_api_tests/test_creation_functions.py::test_empty_like
7+
array_api_tests/test_creation_functions.py::test_eye
8+
array_api_tests/test_creation_functions.py::test_full
9+
array_api_tests/test_creation_functions.py::test_full_like
10+
array_api_tests/test_creation_functions.py::test_linspace
11+
array_api_tests/test_creation_functions.py::test_meshgrid
12+
array_api_tests/test_creation_functions.py::test_ones_like
13+
array_api_tests/test_creation_functions.py::test_zeros_like

_unittests/onnx-ort-skips.txt

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Not implementated by onnxruntime
2+
array_api_tests/test_creation_functions.py::test_arange
3+
array_api_tests/test_creation_functions.py::test_asarray_scalars
4+
array_api_tests/test_creation_functions.py::test_asarray_arrays
5+
array_api_tests/test_creation_functions.py::test_empty
6+
array_api_tests/test_creation_functions.py::test_empty_like
7+
array_api_tests/test_creation_functions.py::test_eye
8+
array_api_tests/test_creation_functions.py::test_full
9+
array_api_tests/test_creation_functions.py::test_full_like
10+
array_api_tests/test_creation_functions.py::test_linspace
11+
array_api_tests/test_creation_functions.py::test_meshgrid
12+
array_api_tests/test_creation_functions.py::test_ones
13+
array_api_tests/test_creation_functions.py::test_ones_like
14+
array_api_tests/test_creation_functions.py::test_zeros
15+
array_api_tests/test_creation_functions.py::test_zeros_like
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import unittest
2+
from inspect import isfunction, ismethod
3+
import numpy as np
4+
from onnx_array_api.ext_test_case import ExtTestCase
5+
from onnx_array_api.array_api import onnx_numpy as xpn
6+
from onnx_array_api.array_api import onnx_ort as xpo
7+
8+
# from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
9+
# from onnx_array_api.ort.ort_tensors import EagerOrtTensor
10+
11+
12+
class TestArraysApis(ExtTestCase):
13+
def test_zeros_numpy_1(self):
14+
c = xpn.zeros(1)
15+
d = c.numpy()
16+
self.assertEqualArray(np.array([0], dtype=np.float32), d)
17+
18+
def test_zeros_ort_1(self):
19+
c = xpo.zeros(1)
20+
d = c.numpy()
21+
self.assertEqualArray(np.array([0], dtype=np.float32), d)
22+
23+
def test_ffinfo(self):
24+
dt = np.float32
25+
fi1 = np.finfo(dt)
26+
fi2 = xpn.finfo(dt)
27+
fi3 = xpo.finfo(dt)
28+
dt1 = fi1.dtype
29+
dt2 = fi2.dtype
30+
dt3 = fi3.dtype
31+
self.assertEqual(dt2, dt3)
32+
self.assertNotEqual(dt1.__class__, dt2.__class__)
33+
mi1 = fi1.min
34+
mi2 = fi2.min
35+
self.assertEqual(mi1, mi2)
36+
mi1 = fi1.smallest_normal
37+
mi2 = fi2.smallest_normal
38+
self.assertEqual(mi1, mi2)
39+
for n in dir(fi1):
40+
if n.startswith("__"):
41+
continue
42+
if n in {"machar"}:
43+
continue
44+
v1 = getattr(fi1, n)
45+
with self.subTest(att=n):
46+
v2 = getattr(fi2, n)
47+
v3 = getattr(fi3, n)
48+
if isfunction(v1) or ismethod(v1):
49+
try:
50+
v1 = v1()
51+
except TypeError:
52+
continue
53+
v2 = v2()
54+
v3 = v3()
55+
if v1 != v2:
56+
raise AssertionError(
57+
f"12: info disagree on name {n!r}: {v1} != {v2}, "
58+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
59+
f"ismethod={ismethod(v1)}."
60+
)
61+
if v2 != v3:
62+
raise AssertionError(
63+
f"23: info disagree on name {n!r}: {v2} != {v3}, "
64+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
65+
f"ismethod={ismethod(v1)}."
66+
)
67+
68+
def test_iiinfo(self):
69+
dt = np.int64
70+
fi1 = np.iinfo(dt)
71+
fi2 = xpn.iinfo(dt)
72+
fi3 = xpo.iinfo(dt)
73+
dt1 = fi1.dtype
74+
dt2 = fi2.dtype
75+
dt3 = fi3.dtype
76+
self.assertEqual(dt2, dt3)
77+
self.assertNotEqual(dt1.__class__, dt2.__class__)
78+
mi1 = fi1.min
79+
mi2 = fi2.min
80+
self.assertEqual(mi1, mi2)
81+
for n in dir(fi1):
82+
if n.startswith("__"):
83+
continue
84+
if n in {"machar"}:
85+
continue
86+
v1 = getattr(fi1, n)
87+
with self.subTest(att=n):
88+
v2 = getattr(fi2, n)
89+
v3 = getattr(fi3, n)
90+
if isfunction(v1) or ismethod(v1):
91+
try:
92+
v1 = v1()
93+
except TypeError:
94+
continue
95+
v2 = v2()
96+
v3 = v3()
97+
if v1 != v2:
98+
raise AssertionError(
99+
f"12: info disagree on name {n!r}: {v1} != {v2}, "
100+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
101+
f"ismethod={ismethod(v1)}."
102+
)
103+
if v2 != v3:
104+
raise AssertionError(
105+
f"23: info disagree on name {n!r}: {v2} != {v3}, "
106+
f"type(v1)={type(v1)}, type(v2)={type(v2)}, "
107+
f"ismethod={ismethod(v1)}."
108+
)
109+
110+
111+
if __name__ == "__main__":
112+
unittest.main(verbosity=2)

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import numpy as np
33
from onnx_array_api.ext_test_case import ExtTestCase
44
from onnx_array_api.array_api import onnx_numpy as xp
5-
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor
5+
from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor as EagerTensor
66

77

88
class TestOnnxNumpy(ExtTestCase):
99
def test_abs(self):
10-
c = EagerNumpyTensor(np.array([4, 5], dtype=np.int64))
10+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
1111
mat = xp.zeros(c, dtype=xp.int64)
1212
matnp = mat.numpy()
1313
self.assertEqual(matnp.shape, (4, 5))
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import unittest
2+
import numpy as np
3+
from onnx_array_api.ext_test_case import ExtTestCase
4+
from onnx_array_api.array_api import onnx_ort as xp
5+
from onnx_array_api.ort.ort_tensors import EagerOrtTensor as EagerTensor
6+
7+
8+
class TestOnnxOrt(ExtTestCase):
9+
def test_abs(self):
10+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
11+
mat = xp.zeros(c, dtype=xp.int64)
12+
matnp = mat.numpy()
13+
self.assertEqual(matnp.shape, (4, 5))
14+
self.assertNotEmpty(matnp[0, 0])
15+
a = xp.absolute(mat)
16+
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
17+
18+
19+
if __name__ == "__main__":
20+
unittest.main(verbosity=2)

_unittests/ut_ort/test_ort_tensor.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,17 @@
11
import unittest
22
from contextlib import redirect_stdout
33
from io import StringIO
4-
54
import numpy as np
65
from onnx.defs import onnx_opset_version
76
from onnx.reference import ReferenceEvaluator
87
from onnxruntime import InferenceSession
9-
108
from onnx_array_api.ext_test_case import ExtTestCase
119
from onnx_array_api.npx import eager_onnx, jit_onnx
1210
from onnx_array_api.npx.npx_functions import absolute as absolute_inline
1311
from onnx_array_api.npx.npx_functions import cdist as cdist_inline
1412
from onnx_array_api.npx.npx_functions_test import absolute
15-
from onnx_array_api.npx.npx_types import Float32, Float64
13+
from onnx_array_api.npx.npx_functions import copy as copy_inline
14+
from onnx_array_api.npx.npx_types import Float32, Float64, DType
1615
from onnx_array_api.npx.npx_var import Input
1716
from onnx_array_api.ort.ort_tensors import EagerOrtTensor, JitOrtTensor, OrtTensor
1817

@@ -193,6 +192,49 @@ def impl(xa, xb):
193192
if len(pieces) > 2:
194193
raise AssertionError(f"Function is not using argument:\n{onx}")
195194

195+
def test_astype(self):
196+
f = absolute_inline(copy_inline(Input("A")).astype(np.float32))
197+
onx = f.to_onnx(constraints={"A": Float64[None]})
198+
x = np.array([[-5, 6]], dtype=np.float64)
199+
z = np.abs(x.astype(np.float32))
200+
ref = InferenceSession(
201+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
202+
)
203+
got = ref.run(None, {"A": x})
204+
self.assertEqualArray(z, got[0])
205+
206+
def test_astype0(self):
207+
f = absolute_inline(copy_inline(Input("A")).astype(np.float32))
208+
onx = f.to_onnx(constraints={"A": Float64[None]})
209+
x = np.array(-5, dtype=np.float64)
210+
z = np.abs(x.astype(np.float32))
211+
ref = InferenceSession(
212+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
213+
)
214+
got = ref.run(None, {"A": x})
215+
self.assertEqualArray(z, got[0])
216+
217+
def test_eager_ort_cast(self):
218+
def impl(A):
219+
return A.astype(DType("FLOAT"))
220+
221+
e = eager_onnx(impl)
222+
self.assertEqual(len(e.versions), 0)
223+
224+
# Float64
225+
x = np.array([0, 1, -2], dtype=np.float64)
226+
z = x.astype(np.float32)
227+
res = e(x)
228+
self.assertEqualArray(z, res)
229+
self.assertEqual(res.dtype, np.float32)
230+
231+
# again
232+
x = np.array(1, dtype=np.float64)
233+
z = x.astype(np.float32)
234+
res = e(x)
235+
self.assertEqualArray(z, res)
236+
self.assertEqual(res.dtype, np.float32)
237+
196238

197239
if __name__ == "__main__":
198240
# TestNpx().test_eager_numpy()

azure-pipelines.yml

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ jobs:
110110
displayName: 'Install tools'
111111
- script: pip install -r requirements.txt
112112
displayName: 'Install Requirements'
113+
- script: pip install onnxruntime
114+
displayName: 'Install onnxruntime'
113115
- script: python setup.py install
114116
displayName: 'Install onnx_array_api'
115117
- script: |
@@ -129,8 +131,13 @@ jobs:
129131
- script: |
130132
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
131133
cd array-api-tests
132-
python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros
133-
displayName: "test_creation_functions.py::test_zeros"
134+
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v
135+
displayName: "numpy test_creation_functions.py"
136+
- script: |
137+
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort
138+
cd array-api-tests
139+
python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt -v
140+
displayName: "ort test_creation_functions.py"
134141
#- script: |
135142
# export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
136143
# cd array-api-tests
@@ -246,16 +253,8 @@ jobs:
246253
displayName: 'export'
247254
- script: gcc --version
248255
displayName: 'gcc version'
249-
- script: brew install llvm
250-
displayName: 'install llvm'
251-
- script: brew install libomp
252-
displayName: 'Install omp'
253-
- script: brew install p7zip
254-
displayName: 'Install p7zip'
255256
- script: python -m pip install --upgrade pip setuptools wheel
256257
displayName: 'Install tools'
257-
- script: brew install pybind11
258-
displayName: 'Install pybind11'
259258
- script: pip install -r requirements.txt
260259
displayName: 'Install Requirements'
261260
- script: pip install -r requirements-dev.txt

onnx_array_api/_helpers.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import numpy as np
2+
from typing import Any
3+
from onnx import helper, TensorProto
4+
5+
6+
def np_dtype_to_tensor_dtype(dtype: Any):
7+
"""
8+
Improves :func:`onnx.helper.np_dtype_to_tensor_dtype`.
9+
"""
10+
try:
11+
dt = helper.np_dtype_to_tensor_dtype(dtype)
12+
except KeyError:
13+
if dtype == np.float32:
14+
dt = TensorProto.FLOAT
15+
elif dtype == np.float64:
16+
dt = TensorProto.DOUBLE
17+
elif dtype == np.int64:
18+
dt = TensorProto.INT64
19+
elif dtype == np.int32:
20+
dt = TensorProto.INT32
21+
elif dtype == np.int16:
22+
dt = TensorProto.INT16
23+
elif dtype == np.int8:
24+
dt = TensorProto.INT8
25+
elif dtype == np.uint64:
26+
dt = TensorProto.UINT64
27+
elif dtype == np.uint32:
28+
dt = TensorProto.UINT32
29+
elif dtype == np.uint16:
30+
dt = TensorProto.UINT16
31+
elif dtype == np.uint8:
32+
dt = TensorProto.UINT8
33+
elif dtype == np.float16:
34+
dt = TensorProto.FLOAT16
35+
elif dtype in (bool, np.bool_):
36+
dt = TensorProto.BOOL
37+
elif dtype in (str, np.str_):
38+
dt = TensorProto.STRING
39+
elif dtype is int:
40+
dt = TensorProto.INT64
41+
elif dtype is float:
42+
dt = TensorProto.FLOAT64
43+
else:
44+
raise KeyError(f"Unable to guess type for dtype={dtype}.")
45+
return dt

onnx_array_api/array_api/__init__.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,42 @@
1+
import numpy as np
12
from onnx import TensorProto
3+
from .._helpers import np_dtype_to_tensor_dtype
24
from ..npx.npx_types import DType
35

46

7+
def _finfo(dtype):
8+
"""
9+
Similar to :class:`numpy.finfo`.
10+
"""
11+
dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
12+
res = np.finfo(dt)
13+
d = res.__dict__.copy()
14+
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
15+
nres = type("finfo", (res.__class__,), d)
16+
setattr(nres, "smallest_normal", res.smallest_normal)
17+
setattr(nres, "tiny", res.tiny)
18+
return nres
19+
20+
21+
def _iinfo(dtype):
22+
"""
23+
Similar to :class:`numpy.finfo`.
24+
"""
25+
dt = dtype.np_dtype if isinstance(dtype, DType) else dtype
26+
res = np.iinfo(dt)
27+
d = res.__dict__.copy()
28+
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
29+
nres = type("finfo", (res.__class__,), d)
30+
setattr(nres, "min", res.min)
31+
setattr(nres, "max", res.max)
32+
return nres
33+
34+
535
def _finalize_array_api(module):
36+
"""
37+
Adds common attributes to Array API defined in this modules
38+
such as types.
39+
"""
640
module.float16 = DType(TensorProto.FLOAT16)
741
module.float32 = DType(TensorProto.FLOAT)
842
module.float64 = DType(TensorProto.DOUBLE)
@@ -17,3 +51,5 @@ def _finalize_array_api(module):
1751
module.bfloat16 = DType(TensorProto.BFLOAT16)
1852
setattr(module, "bool", DType(TensorProto.BOOL))
1953
setattr(module, "str", DType(TensorProto.STRING))
54+
setattr(module, "finfo", _finfo)
55+
setattr(module, "iinfo", _iinfo)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy