Skip to content

Commit c6a3718

Browse files
authored
Fixes asarray for the Array API (#25)
* Fixes asarray for the Array API * move
1 parent 61eec9d commit c6a3718

File tree

7 files changed

+77
-17
lines changed

7 files changed

+77
-17
lines changed

_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from os import getenv
44
from functools import reduce
5+
import numpy as np
56
from operator import mul
67
from hypothesis import given
78
from onnx_array_api.ext_test_case import ExtTestCase
@@ -89,24 +90,49 @@ def test_scalar_strategies(self):
8990

9091
args_np = []
9192

93+
xx = self.xps.arrays(dtype=dtypes["integer_dtypes"], shape=shapes(self.xps))
94+
kws = array_api_kwargs(dtype=strategies.none() | self.xps.scalar_dtypes())
95+
9296
@given(
93-
x=self.xps.arrays(dtype=dtypes["integer_dtypes"], shape=shapes(self.xps)),
94-
kw=array_api_kwargs(dtype=strategies.none() | self.xps.scalar_dtypes()),
97+
x=xx,
98+
kw=kws,
9599
)
96-
def fct(x, kw):
100+
def fctnp(x, kw):
101+
asa1 = np.asarray(x)
102+
asa2 = np.asarray(x, **kw)
103+
self.assertEqual(asa1.shape, asa2.shape)
97104
args_np.append((x, kw))
98105

99-
fct()
106+
fctnp()
100107
self.assertEqual(len(args_np), 100)
101108

102109
args_onxp = []
103110

104111
xshape = shapes(self.onxps)
105112
xx = self.onxps.arrays(dtype=dtypes_onnx["integer_dtypes"], shape=xshape)
106-
kw = array_api_kwargs(dtype=strategies.none() | self.onxps.scalar_dtypes())
113+
kws = array_api_kwargs(dtype=strategies.none() | self.onxps.scalar_dtypes())
107114

108-
@given(x=xx, kw=kw)
115+
@given(x=xx, kw=kws)
109116
def fctonx(x, kw):
117+
asa = np.asarray(x.numpy())
118+
try:
119+
asp = onxp.asarray(x)
120+
except Exception as e:
121+
raise AssertionError(f"asarray fails with x={x!r}, asp={asa!r}.") from e
122+
try:
123+
self.assertEqualArray(asa, asp.numpy())
124+
except AssertionError as e:
125+
raise AssertionError(
126+
f"x={x!r} kw={kw!r} asa={asa!r}, asp={asp!r}"
127+
) from e
128+
if kw:
129+
try:
130+
asp2 = onxp.asarray(x, **kw)
131+
except Exception as e:
132+
raise AssertionError(
133+
f"asarray fails with x={x!r}, kw={kw!r}, asp={asa!r}."
134+
) from e
135+
self.assertEqual(asp.shape, asp2.shape)
110136
args_onxp.append((x, kw))
111137

112138
fctonx()

onnx_array_api/array_api/_onnx_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from typing import Any, Optional
2+
import warnings
23
import numpy as np
4+
5+
with warnings.catch_warnings():
6+
warnings.simplefilter("ignore")
7+
from numpy.array_api._array_object import Array
38
from ..npx.npx_types import (
49
DType,
510
ElemType,
@@ -77,6 +82,10 @@ def asarray(
7782
v = TEagerTensor(np.array(a, dtype=np.str_))
7883
elif isinstance(a, list):
7984
v = TEagerTensor(np.array(a))
85+
elif isinstance(a, np.ndarray):
86+
v = TEagerTensor(a)
87+
elif isinstance(a, Array):
88+
v = TEagerTensor(np.asarray(a))
8089
else:
8190
raise RuntimeError(f"Unexpected type {type(a)} for the first input.")
8291
if dtype is not None:

onnx_array_api/npx/npx_numpy_tensors.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Any, Callable, List, Optional, Tuple
23
import numpy as np
34
from onnx import ModelProto, TensorProto
@@ -221,13 +222,18 @@ def __bool__(self):
221222
if self.shape == (0,):
222223
return False
223224
if len(self.shape) != 0:
224-
raise ValueError(
225-
f"Conversion to bool only works for scalar, not for {self!r}."
225+
warnings.warn(
226+
f"Conversion to bool only works for scalar, not for {self!r}, "
227+
f"bool(...)={bool(self._tensor)}."
226228
)
229+
try:
230+
return bool(self._tensor)
231+
except ValueError as e:
232+
raise ValueError(f"Unable to convert {self} to bool.") from e
227233
return bool(self._tensor)
228234

229235
def __int__(self):
230-
"Implicit conversion to bool."
236+
"Implicit conversion to int."
231237
if len(self.shape) != 0:
232238
raise ValueError(
233239
f"Conversion to bool only works for scalar, not for {self!r}."
@@ -249,7 +255,7 @@ def __int__(self):
249255
return int(self._tensor)
250256

251257
def __float__(self):
252-
"Implicit conversion to bool."
258+
"Implicit conversion to float."
253259
if len(self.shape) != 0:
254260
raise ValueError(
255261
f"Conversion to bool only works for scalar, not for {self!r}."
@@ -261,11 +267,24 @@ def __float__(self):
261267
DType(TensorProto.BFLOAT16),
262268
}:
263269
raise TypeError(
264-
f"Conversion to int only works for float scalar, "
270+
f"Conversion to float only works for float scalar, "
265271
f"not for dtype={self.dtype}."
266272
)
267273
return float(self._tensor)
268274

275+
def __iter__(self):
276+
"""
277+
The :epkg:`Array API` does not define this function (2022/12).
278+
This method raises an exception with a better error message.
279+
"""
280+
warnings.warn(
281+
f"Iterators are not implemented in the generic case. "
282+
f"Every function using them cannot be converted into ONNX "
283+
f"(tensors - {type(self)})."
284+
)
285+
for row in self._tensor:
286+
yield self.__class__(row)
287+
269288

270289
class JitNumpyTensor(NumpyTensor, JitTensor):
271290
"""

onnx_array_api/npx/npx_tensors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ def __iter__(self):
3535
This method raises an exception with a better error message.
3636
"""
3737
raise ArrayApiError(
38-
"Iterators are not implemented in the generic case. "
39-
"Every function using them cannot be converted into ONNX."
38+
f"Iterators are not implemented in the generic case. "
39+
f"Every function using them cannot be converted into ONNX "
40+
f"(tensors - {type(self)})."
4041
)
4142

4243
@staticmethod

onnx_array_api/npx/npx_types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,16 @@ def __eq__(self, dt: "DType") -> bool:
5959
return False
6060
if dt.__class__ is DType:
6161
return self.code_ == dt.code_
62-
if isinstance(dt, (int, bool, str)):
62+
if isinstance(dt, (int, bool, str, float)):
6363
return False
64+
if dt is int:
65+
return self.code_ == TensorProto.INT64
6466
if dt is str:
6567
return self.code_ == TensorProto.STRING
6668
if dt is bool:
6769
return self.code_ == TensorProto.BOOL
70+
if dt is float:
71+
return self.code_ == TensorProto.FLOAT64
6872
if isinstance(dt, list):
6973
return False
7074
if dt in ElemType.numpy_map:

onnx_array_api/npx/npx_var.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,9 @@ def __iter__(self):
607607
This method raises an exception with a better error message.
608608
"""
609609
raise ArrayApiError(
610-
"Iterators are not implemented in the generic case. "
611-
"Every function using them cannot be converted into ONNX."
610+
f"Iterators are not implemented in the generic case. "
611+
f"Every function using them cannot be converted into ONNX "
612+
f"(Var - {type(self)})."
612613
)
613614

614615
def _binary_op(self, ov: "Var", op_name: str, **kwargs) -> "Var":

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ black
33
coverage
44
flake8
55
furo
6-
hypothesis<6.80.0
6+
hypothesis
77
isort
88
joblib
99
lightgbm

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy