Skip to content

Commit c82f9f3

Browse files
authored
Supports function full for the Array API (#21)
* Supports function full for the Array API * improvments * fix keys by adding types * fix unit tests * ci
1 parent ce37364 commit c82f9f3

17 files changed

+175
-44
lines changed

_unittests/onnx-numpy-skips.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
55
array_api_tests/test_creation_functions.py::test_empty
66
array_api_tests/test_creation_functions.py::test_empty_like
77
array_api_tests/test_creation_functions.py::test_eye
8-
array_api_tests/test_creation_functions.py::test_full
98
array_api_tests/test_creation_functions.py::test_full_like
109
array_api_tests/test_creation_functions.py::test_linspace
1110
array_api_tests/test_creation_functions.py::test_meshgrid

_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_arrays || exit 1
2+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars || exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1

_unittests/ut_array_api/test_array_apis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TestArraysApis(ExtTestCase):
1313
def test_zeros_numpy_1(self):
1414
c = xpn.zeros(1)
1515
d = c.numpy()
16-
self.assertEqualArray(np.array([0], dtype=np.float32), d)
16+
self.assertEqualArray(np.array([0], dtype=np.float64), d)
1717

1818
def test_zeros_ort_1(self):
1919
c = xpo.zeros(1)

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,40 @@ def test_zeros(self):
1919
a = xp.absolute(mat)
2020
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
2121

22+
def test_zeros_none(self):
23+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
24+
mat = xp.zeros(c)
25+
matnp = mat.numpy()
26+
self.assertEqual(matnp.shape, (4, 5))
27+
self.assertNotEmpty(matnp[0, 0])
28+
self.assertEqualArray(matnp, np.zeros((4, 5)))
29+
30+
def test_ones_none(self):
31+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
32+
mat = xp.ones(c)
33+
matnp = mat.numpy()
34+
self.assertEqual(matnp.shape, (4, 5))
35+
self.assertNotEmpty(matnp[0, 0])
36+
self.assertEqualArray(matnp, np.ones((4, 5)))
37+
38+
def test_full(self):
39+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
40+
mat = xp.full(c, fill_value=5, dtype=xp.int64)
41+
matnp = mat.numpy()
42+
self.assertEqual(matnp.shape, (4, 5))
43+
self.assertNotEmpty(matnp[0, 0])
44+
a = xp.absolute(mat)
45+
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
46+
47+
def test_full_bool(self):
48+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
49+
mat = xp.full(c, fill_value=False)
50+
matnp = mat.numpy()
51+
self.assertEqual(matnp.shape, (4, 5))
52+
self.assertNotEmpty(matnp[0, 0])
53+
self.assertEqualArray(matnp, np.full((4, 5), False))
54+
2255

2356
if __name__ == "__main__":
57+
TestOnnxNumpy().test_zeros_none()
2458
unittest.main(verbosity=2)

_unittests/ut_npx/test_npx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,8 @@ def impl(
710710
keys = list(sorted(f.onxs))
711711
self.assertIsInstance(f.onxs[keys[0]], ModelProto)
712712
k = keys[-1]
713-
self.assertEqual(len(k), 3)
714-
self.assertEqual(k[1:], ("axis", 0))
713+
self.assertEqual(len(k), 4)
714+
self.assertEqual(k[1:], ("axis", int, 0))
715715

716716
def test_numpy_topk(self):
717717
f = topk(Input("X"), Input("K"))
@@ -2416,6 +2416,7 @@ def compute_labels(X, centers, use_sqrt=False):
24162416
(DType(TensorProto.DOUBLE), 2),
24172417
(DType(TensorProto.DOUBLE), 2),
24182418
"use_sqrt",
2419+
bool,
24192420
True,
24202421
)
24212422
self.assertEqual(f.available_versions, [key])

azure-pipelines.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
vmImage: 'ubuntu-latest'
4949
strategy:
5050
matrix:
51-
Python310-Linux:
51+
Python311-Linux:
5252
python.version: '3.11'
5353
maxParallel: 3
5454

@@ -96,7 +96,7 @@ jobs:
9696
strategy:
9797
matrix:
9898
Python310-Linux:
99-
python.version: '3.11'
99+
python.version: '3.10'
100100
maxParallel: 3
101101

102102
steps:
@@ -149,7 +149,7 @@ jobs:
149149
vmImage: 'ubuntu-latest'
150150
strategy:
151151
matrix:
152-
Python310-Linux:
152+
Python311-Linux:
153153
python.version: '3.11'
154154
maxParallel: 3
155155

@@ -202,7 +202,7 @@ jobs:
202202
vmImage: 'windows-latest'
203203
strategy:
204204
matrix:
205-
Python310-Windows:
205+
Python311-Windows:
206206
python.version: '3.11'
207207
maxParallel: 3
208208

@@ -235,7 +235,7 @@ jobs:
235235
vmImage: 'macOS-latest'
236236
strategy:
237237
matrix:
238-
Python310-Mac:
238+
Python311-Mac:
239239
python.version: '3.11'
240240
maxParallel: 3
241241

onnx_array_api/_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def np_dtype_to_tensor_dtype(dtype: Any):
3939
elif dtype is int:
4040
dt = TensorProto.INT64
4141
elif dtype is float:
42-
dt = TensorProto.FLOAT64
42+
dt = TensorProto.DOUBLE
4343
else:
4444
raise KeyError(f"Unable to guess type for dtype={dtype}.")
4545
return dt

onnx_array_api/array_api/_onnx_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def template_asarray(
4444
except OverflowError:
4545
v = TEagerTensor(np.asarray(a, dtype=np.uint64))
4646
elif isinstance(a, float):
47-
v = TEagerTensor(np.array(a, dtype=np.float32))
47+
v = TEagerTensor(np.array(a, dtype=np.float64))
4848
elif isinstance(a, bool):
4949
v = TEagerTensor(np.array(a, dtype=np.bool_))
5050
elif isinstance(a, str):

onnx_array_api/array_api/onnx_numpy.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44
from typing import Any, Optional
55
import numpy as np
6-
from onnx import TensorProto
76
from ..npx.npx_functions import (
87
all,
98
abs,
@@ -16,10 +15,11 @@
1615
reshape,
1716
take,
1817
)
18+
from ..npx.npx_functions import full as generic_full
1919
from ..npx.npx_functions import ones as generic_ones
2020
from ..npx.npx_functions import zeros as generic_zeros
2121
from ..npx.npx_numpy_tensors import EagerNumpyTensor
22-
from ..npx.npx_types import DType, ElemType, TensorType, OptParType
22+
from ..npx.npx_types import DType, ElemType, TensorType, OptParType, ParType, Scalar
2323
from ._onnx_common import template_asarray
2424
from . import _finalize_array_api
2525

@@ -31,6 +31,7 @@
3131
"astype",
3232
"empty",
3333
"equal",
34+
"full",
3435
"isdtype",
3536
"isfinite",
3637
"isnan",
@@ -58,7 +59,7 @@ def asarray(
5859

5960
def ones(
6061
shape: TensorType[ElemType.int64, "I", (None,)],
61-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
62+
dtype: OptParType[DType] = None,
6263
order: OptParType[str] = "C",
6364
) -> TensorType[ElemType.numerics, "T"]:
6465
if isinstance(shape, tuple):
@@ -76,7 +77,7 @@ def ones(
7677

7778
def empty(
7879
shape: TensorType[ElemType.int64, "I", (None,)],
79-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
80+
dtype: OptParType[DType] = None,
8081
order: OptParType[str] = "C",
8182
) -> TensorType[ElemType.numerics, "T"]:
8283
raise RuntimeError(
@@ -87,7 +88,7 @@ def empty(
8788

8889
def zeros(
8990
shape: TensorType[ElemType.int64, "I", (None,)],
90-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
91+
dtype: OptParType[DType] = None,
9192
order: OptParType[str] = "C",
9293
) -> TensorType[ElemType.numerics, "T"]:
9394
if isinstance(shape, tuple):
@@ -103,6 +104,32 @@ def zeros(
103104
return generic_zeros(shape, dtype=dtype, order=order)
104105

105106

107+
def full(
108+
shape: TensorType[ElemType.int64, "I", (None,)],
109+
fill_value: ParType[Scalar] = None,
110+
dtype: OptParType[DType] = None,
111+
order: OptParType[str] = "C",
112+
) -> TensorType[ElemType.numerics, "T"]:
113+
if fill_value is None:
114+
raise TypeError("fill_value cannot be None")
115+
value = fill_value
116+
if isinstance(shape, tuple):
117+
return generic_full(
118+
EagerNumpyTensor(np.array(shape, dtype=np.int64)),
119+
fill_value=value,
120+
dtype=dtype,
121+
order=order,
122+
)
123+
if isinstance(shape, int):
124+
return generic_full(
125+
EagerNumpyTensor(np.array([shape], dtype=np.int64)),
126+
fill_value=value,
127+
dtype=dtype,
128+
order=order,
129+
)
130+
return generic_full(shape, fill_value=value, dtype=dtype, order=order)
131+
132+
106133
def _finalize():
107134
"""
108135
Adds common attributes to Array API defined in this modules

onnx_array_api/npx/npx_core_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def wrapper(*inputs, **kwargs):
169169
new_inputs.append(i)
170170
elif isinstance(i, (int, float)):
171171
new_inputs.append(
172-
np.array([i], dtype=np.int64 if isinstance(i, int) else np.float32)
172+
np.array([i], dtype=np.int64 if isinstance(i, int) else np.float64)
173173
)
174174
elif isinstance(i, str):
175175
new_inputs.append(Input(i))

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy