From 791eb37fa9fc018abb9a89c4464bec65f5e9516d Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 10:46:25 +0200 Subject: [PATCH 1/5] Supports function full for the Array API --- _unittests/onnx-numpy-skips.txt | 2 +- _unittests/test_array_api.sh | 2 +- _unittests/ut_array_api/test_onnx_numpy.py | 10 +++++++ onnx_array_api/array_api/onnx_numpy.py | 30 ++++++++++++++++++++- onnx_array_api/npx/npx_functions.py | 30 ++++++++++++++++++++- onnx_array_api/npx/npx_graph_builder.py | 2 +- onnx_array_api/npx/npx_jit_eager.py | 14 +++++----- onnx_array_api/npx/npx_numpy_tensors_ops.py | 2 ++ onnx_array_api/npx/npx_tensors.py | 2 +- onnx_array_api/npx/npx_types.py | 28 ++++++++++++++++--- onnx_array_api/npx/npx_var.py | 16 ++++++----- 11 files changed, 117 insertions(+), 21 deletions(-) diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index eef3e70..3beafc6 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -5,7 +5,7 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like array_api_tests/test_creation_functions.py::test_eye -array_api_tests/test_creation_functions.py::test_full +# array_api_tests/test_creation_functions.py::test_full array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index cb32fe4..c75a61b 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -# pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_arrays || exit 1 +pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_full || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 \ No newline at end of file diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index bd79ecf..100ed2a 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -19,6 +19,16 @@ def test_zeros(self): a = xp.absolute(mat) self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + def test_full(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.full(c, fill_value=5, dtype=xp.int64) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + a = xp.absolute(mat) + self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + if __name__ == "__main__": + TestOnnxNumpy().test_full() unittest.main(verbosity=2) diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 2cd4bfd..4825bd6 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -16,10 +16,11 @@ reshape, take, ) +from ..npx.npx_functions import full as generic_full from ..npx.npx_functions import ones as generic_ones from ..npx.npx_functions import zeros as generic_zeros from ..npx.npx_numpy_tensors import EagerNumpyTensor -from ..npx.npx_types import DType, ElemType, TensorType, OptParType +from ..npx.npx_types import DType, ElemType, TensorType, OptParType, ParType, Scalar from ._onnx_common import template_asarray from . import _finalize_array_api @@ -31,6 +32,7 @@ "astype", "empty", "equal", + "full", "isdtype", "isfinite", "isnan", @@ -103,6 +105,32 @@ def zeros( return generic_zeros(shape, dtype=dtype, order=order) +def full( + shape: TensorType[ElemType.int64, "I", (None,)], + fill_value: ParType[Scalar] = None, + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + if fill_value is None: + raise AttributeError("fill_value cannot be None") + value = fill_value + if isinstance(shape, tuple): + return generic_full( + EagerNumpyTensor(np.array(shape, dtype=np.int64)), + fill_value=value, + dtype=dtype, + order=order, + ) + if isinstance(shape, int): + return generic_full( + EagerNumpyTensor(np.array([shape], dtype=np.int64)), + fill_value=value, + dtype=dtype, + order=order, + ) + return generic_full(shape, fill_value=value, dtype=dtype, order=order) + + def _finalize(): """ Adds common attributes to Array API defined in this modules diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 29a4481..c223f0d 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -15,6 +15,7 @@ SequenceType, TensorType, TupleType, + Scalar, ) from .npx_var import Var @@ -22,7 +23,7 @@ def _cstv(x): if isinstance(x, Var): return x - if isinstance(x, (int, float, np.ndarray)): + if isinstance(x, (int, float, bool, np.ndarray)): return cst(x) raise TypeError(f"Unexpected constant type {type(x)}.") @@ -376,6 +377,33 @@ def expit(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics return var(x, op="Sigmoid") +@npxapi_inline +def full( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + fill_value: ParType[Scalar] = None, + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + """ + Implements :func:`numpy.zeros`. + """ + if order != "C": + raise RuntimeError(f"order={order!r} != 'C' not supported.") + if fill_value is None: + raise AttributeError("fill_value cannot be None.") + if dtype is None: + dtype = DType(TensorProto.FLOAT) + if isinstance(fill_value, (float, int, bool)): + value = make_tensor( + name="cst", data_type=dtype.code, dims=[1], vals=[fill_value] + ) + else: + raise NotImplementedError( + f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}." + ) + return var(shape, value=value, op="ConstantOfShape") + + @npxapi_inline def floor(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: "See :func:`numpy.floor`." diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index d41b91c..ff02843 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -784,7 +784,7 @@ def to_onnx( node_inputs.append(input_name) continue - if isinstance(i, (int, float)): + if isinstance(i, (int, float, bool)): ni = np.array(i) c = Cst(ni) input_name = self._unique(var._prefix) diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 35ff9af..5f30d30 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -131,7 +131,7 @@ def make_key(*values, **kwargs): for iv, v in enumerate(values): if isinstance(v, (Var, EagerTensor, JitTensor)): res.append(v.key) - elif isinstance(v, (int, float, DType)): + elif isinstance(v, (int, float, bool, DType)): res.append(v) elif isinstance(v, slice): res.append(("slice", v.start, v.stop, v.step)) @@ -153,7 +153,7 @@ def make_key(*values, **kwargs): ) if kwargs: for k, v in sorted(kwargs.items()): - if isinstance(v, (int, float, str, type, DType)): + if isinstance(v, (int, float, str, type, bool, DType)): res.append(k) res.append(v) elif isinstance(v, tuple): @@ -543,12 +543,12 @@ def _preprocess_constants(self, *args): elif isinstance(n, np.ndarray): new_args.append(self.tensor_class(n)) modified = True - elif isinstance(n, (int, float)): + elif isinstance(n, (int, float, bool)): new_args.append(self.tensor_class(np.array(n))) modified = True elif isinstance(n, DType): new_args.append(n) - elif n in (int, float): + elif n in (int, float, bool): # usually used to cast new_args.append(n) elif n is None: @@ -586,6 +586,7 @@ def __call__(self, *args, already_eager=False, **kwargs): EagerTensor, Cst, int, + bool, float, tuple, slice, @@ -616,12 +617,13 @@ def __call__(self, *args, already_eager=False, **kwargs): else: # tries to call the version try: - res = self.f(*values) + res = self.f(*values, **kwargs) except (AttributeError, TypeError) as e: inp1 = ", ".join(map(str, map(type, args))) inp2 = ", ".join(map(str, map(type, values))) raise TypeError( - f"Unexpected types, input types are {inp1} " f"and {inp2}." + f"Unexpected types, input types are {inp1} " + f"and {inp2}, kwargs={kwargs}." ) from e if isinstance(res, EagerTensor) or ( diff --git a/onnx_array_api/npx/npx_numpy_tensors_ops.py b/onnx_array_api/npx/npx_numpy_tensors_ops.py index 5278019..c9cae2f 100644 --- a/onnx_array_api/npx/npx_numpy_tensors_ops.py +++ b/onnx_array_api/npx/npx_numpy_tensors_ops.py @@ -11,6 +11,8 @@ def _process(value): cst = np.int64(cst) elif isinstance(cst, float): cst = np.float64(cst) + elif isinstance(cst, bool): + cst = np.bool_(cst) elif cst is None: cst = np.float32(0) if not isinstance( diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py index b0e92c2..9286ae2 100644 --- a/onnx_array_api/npx/npx_tensors.py +++ b/onnx_array_api/npx/npx_tensors.py @@ -133,7 +133,7 @@ def _generic_method_operator(self, method_name, *args: Any, **kwargs: Any) -> An for a in args: if isinstance(a, np.ndarray): new_args.append(self.__class__(a.astype(self.dtype.np_dtype))) - elif isinstance(a, (int, float)): + elif isinstance(a, (int, float, bool)): new_args.append( self.__class__(np.array([a]).astype(self.dtype.np_dtype)) ) diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index 6063e64..0f7f6dc 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -292,6 +292,19 @@ def get_set_name(cls, dtypes): return None +class Scalar: + """ + Defines a scalar. + """ + + def __init__(self, value: Union[float, int, bool]): + self.value = value + + def __repr__(self): + "usual" + return f"Scalar({self.value!r})" + + class ParType(WrapperType): """ Defines a parameter type. @@ -300,11 +313,18 @@ class ParType(WrapperType): :param optional: is optional or not """ - map_names = {int: "int", float: "float", str: "str", DType: "DType"} + map_names = { + int: "int", + float: "float", + str: "str", + DType: "DType", + bool: "bool", + Scalar: "Scalar", + } @classmethod def __class_getitem__(cls, dtype): - if isinstance(dtype, (int, float)): + if isinstance(dtype, (int, float, bool)): msg = str(dtype) else: msg = getattr(dtype, "__name__", str(dtype)) @@ -331,6 +351,8 @@ def onnx_type(cls): return AttributeProto.INT if cls.dtype == float: return AttributeProto.FLOAT + if cls.dtype == bool: + return AttributeProto.BOOL if cls.dtype == str: return AttributeProto.STRING raise RuntimeError( @@ -347,7 +369,7 @@ class OptParType(ParType): @classmethod def __class_getitem__(cls, dtype): - if isinstance(dtype, (int, float)): + if isinstance(dtype, (int, float, bool)): msg = str(dtype) else: msg = dtype.__name__ diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 2759f4c..3f5e090 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -12,7 +12,7 @@ class Par: Defines a named parameter. :param name: parameter name - :param dtype: parameter type (int, str, float) + :param dtype: parameter type (bool, int, str, float) :param value: value of the parameter if known :param parent_op: node type it belongs to """ @@ -233,7 +233,7 @@ def __call__(self, new_values): def _setitem1_where(self, index, new_values): cst, var = Var.get_cst_var() - if isinstance(new_values, (int, float)): + if isinstance(new_values, (int, float, bool)): new_values = np.array(new_values) if isinstance(new_values, np.ndarray): value = var(cst(new_values), self.parent, op="CastLike") @@ -446,7 +446,7 @@ def _get_vars(self): cst = Var.get_cst_var()[0] replacement_cst[id(i)] = cst(i) continue - if isinstance(i, (int, float)): + if isinstance(i, (int, float, bool)): cst = Var.get_cst_var()[0] replacement_cst[id(i)] = cst(np.array(i)) continue @@ -595,13 +595,13 @@ def __iter__(self): def _binary_op(self, ov: "Var", op_name: str, **kwargs) -> "Var": var = Var.get_cst_var()[1] - if isinstance(ov, (int, float, np.ndarray, Cst)): + if isinstance(ov, (int, float, bool, np.ndarray, Cst)): return var(self.self_var, var(ov, self.self_var, op="CastLike"), op=op_name) return var(self.self_var, ov, op=op_name, **kwargs) def _binary_op_right(self, ov: "Var", op_name: str, **kwargs) -> "Var": var = Var.get_cst_var()[1] - if isinstance(ov, (int, float, np.ndarray, Cst)): + if isinstance(ov, (int, float, bool, np.ndarray, Cst)): return var(var(ov, self.self_var, op="CastLike"), self.self_var, op=op_name) return var(ov, self.self_var, op=op_name, **kwargs) @@ -1112,10 +1112,14 @@ def __init__(self, cst: Any): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif isinstance(cst, float): Var.__init__(self, np.array(cst, dtype=np.float32), op="Identity") + elif isinstance(cst, bool): + Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") elif isinstance(cst, list): if all(map(lambda t: isinstance(t, int), cst)): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") - elif all(map(lambda t: isinstance(t, (float, int)), cst)): + elif all(map(lambda t: isinstance(t, bool), cst)): + Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") + elif all(map(lambda t: isinstance(t, (float, int, bool)), cst)): Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") else: raise ValueError( From 7796b45bf67fe2ad4fa9f111ac60f73046e1ddd0 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 11:18:31 +0200 Subject: [PATCH 2/5] improvments --- _unittests/ut_array_api/test_onnx_numpy.py | 10 +++++++++- onnx_array_api/array_api/_onnx_common.py | 2 +- onnx_array_api/array_api/onnx_numpy.py | 4 ++-- onnx_array_api/npx/npx_core_api.py | 2 +- onnx_array_api/npx/npx_functions.py | 17 +++++++++++++---- onnx_array_api/npx/npx_jit_eager.py | 1 + onnx_array_api/npx/npx_numpy_tensors_ops.py | 6 +++--- onnx_array_api/npx/npx_var.py | 12 ++++++------ 8 files changed, 36 insertions(+), 18 deletions(-) diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 100ed2a..55b2d94 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -28,7 +28,15 @@ def test_full(self): a = xp.absolute(mat) self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + def test_full_bool(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.full(c, fill_value=False) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + self.assertEqualArray(matnp, np.full((4, 5), False)) + if __name__ == "__main__": - TestOnnxNumpy().test_full() + TestOnnxNumpy().test_full_bool() unittest.main(verbosity=2) diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 6553137..f832b72 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -44,7 +44,7 @@ def template_asarray( except OverflowError: v = TEagerTensor(np.asarray(a, dtype=np.uint64)) elif isinstance(a, float): - v = TEagerTensor(np.array(a, dtype=np.float32)) + v = TEagerTensor(np.array(a, dtype=np.float64)) elif isinstance(a, bool): v = TEagerTensor(np.array(a, dtype=np.bool_)) elif isinstance(a, str): diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 4825bd6..9f50d3f 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -108,11 +108,11 @@ def zeros( def full( shape: TensorType[ElemType.int64, "I", (None,)], fill_value: ParType[Scalar] = None, - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: if fill_value is None: - raise AttributeError("fill_value cannot be None") + raise TypeError("fill_value cannot be None") value = fill_value if isinstance(shape, tuple): return generic_full( diff --git a/onnx_array_api/npx/npx_core_api.py b/onnx_array_api/npx/npx_core_api.py index 05cb0bb..548a40a 100644 --- a/onnx_array_api/npx/npx_core_api.py +++ b/onnx_array_api/npx/npx_core_api.py @@ -169,7 +169,7 @@ def wrapper(*inputs, **kwargs): new_inputs.append(i) elif isinstance(i, (int, float)): new_inputs.append( - np.array([i], dtype=np.int64 if isinstance(i, int) else np.float32) + np.array([i], dtype=np.int64 if isinstance(i, int) else np.float64) ) elif isinstance(i, str): new_inputs.append(Input(i)) diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index c223f0d..ab923b7 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -380,19 +380,28 @@ def expit(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics @npxapi_inline def full( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, fill_value: ParType[Scalar] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: """ - Implements :func:`numpy.zeros`. + Implements :func:`numpy.full`. """ if order != "C": raise RuntimeError(f"order={order!r} != 'C' not supported.") if fill_value is None: - raise AttributeError("fill_value cannot be None.") + raise TypeError("fill_value cannot be None.") if dtype is None: - dtype = DType(TensorProto.FLOAT) + if isinstance(fill_value, bool): + dtype = DType(TensorProto.BOOL) + elif isinstance(fill_value, int): + dtype = DType(TensorProto.INT64) + elif isinstance(fill_value, float): + dtype = DType(TensorProto.DOUBLE) + else: + raise TypeError( + f"Unexpected type {type(fill_value)} for fill_value={fill_value!r}." + ) if isinstance(fill_value, (float, int, bool)): value = make_tensor( name="cst", data_type=dtype.code, dims=[1], vals=[fill_value] diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 5f30d30..bfb87fe 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -155,6 +155,7 @@ def make_key(*values, **kwargs): for k, v in sorted(kwargs.items()): if isinstance(v, (int, float, str, type, bool, DType)): res.append(k) + res.append(type(v)) res.append(v) elif isinstance(v, tuple): newv = [] diff --git a/onnx_array_api/npx/npx_numpy_tensors_ops.py b/onnx_array_api/npx/npx_numpy_tensors_ops.py index c9cae2f..b4639ae 100644 --- a/onnx_array_api/npx/npx_numpy_tensors_ops.py +++ b/onnx_array_api/npx/npx_numpy_tensors_ops.py @@ -7,12 +7,12 @@ class ConstantOfShape(OpRun): @staticmethod def _process(value): cst = value[0] if isinstance(value, np.ndarray) else value - if isinstance(cst, int): + if isinstance(cst, bool): + cst = np.bool_(cst) + elif isinstance(cst, int): cst = np.int64(cst) elif isinstance(cst, float): cst = np.float64(cst) - elif isinstance(cst, bool): - cst = np.bool_(cst) elif cst is None: cst = np.float32(0) if not isinstance( diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 3f5e090..a4802e3 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -1108,17 +1108,17 @@ class Cst(Var): def __init__(self, cst: Any): if isinstance(cst, np.ndarray): Var.__init__(self, cst, op="Identity") + elif isinstance(cst, bool): + Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") elif isinstance(cst, int): Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif isinstance(cst, float): - Var.__init__(self, np.array(cst, dtype=np.float32), op="Identity") - elif isinstance(cst, bool): - Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") + Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") elif isinstance(cst, list): - if all(map(lambda t: isinstance(t, int), cst)): - Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") - elif all(map(lambda t: isinstance(t, bool), cst)): + if all(map(lambda t: isinstance(t, bool), cst)): Var.__init__(self, np.array(cst, dtype=np.bool_), op="Identity") + elif all(map(lambda t: isinstance(t, (int, bool)), cst)): + Var.__init__(self, np.array(cst, dtype=np.int64), op="Identity") elif all(map(lambda t: isinstance(t, (float, int, bool)), cst)): Var.__init__(self, np.array(cst, dtype=np.float64), op="Identity") else: From 870ba4b8990f6f73730ebfc7d15eb27c717f9b51 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 11:38:13 +0200 Subject: [PATCH 3/5] fix keys by adding types --- _unittests/test_array_api.sh | 2 +- _unittests/ut_array_api/test_onnx_numpy.py | 18 +++++++++++++++++- _unittests/ut_npx/test_npx.py | 5 +++-- onnx_array_api/_helpers.py | 2 +- onnx_array_api/array_api/onnx_numpy.py | 7 +++---- onnx_array_api/npx/npx_functions.py | 8 ++++---- onnx_array_api/npx/npx_jit_eager.py | 4 +++- 7 files changed, 32 insertions(+), 14 deletions(-) diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index c75a61b..1de8dfb 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_full || exit 1 +pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_ones || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 \ No newline at end of file diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 55b2d94..4cb7544 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -19,6 +19,22 @@ def test_zeros(self): a = xp.absolute(mat) self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + def test_zeros_none(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.zeros(c) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + self.assertEqualArray(matnp, np.zeros((4, 5))) + + def test_ones_none(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.ones(c) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + self.assertEqualArray(matnp, np.ones((4, 5))) + def test_full(self): c = EagerTensor(np.array([4, 5], dtype=np.int64)) mat = xp.full(c, fill_value=5, dtype=xp.int64) @@ -38,5 +54,5 @@ def test_full_bool(self): if __name__ == "__main__": - TestOnnxNumpy().test_full_bool() + TestOnnxNumpy().test_zeros_none() unittest.main(verbosity=2) diff --git a/_unittests/ut_npx/test_npx.py b/_unittests/ut_npx/test_npx.py index 93f2b5e..17b5863 100644 --- a/_unittests/ut_npx/test_npx.py +++ b/_unittests/ut_npx/test_npx.py @@ -710,8 +710,8 @@ def impl( keys = list(sorted(f.onxs)) self.assertIsInstance(f.onxs[keys[0]], ModelProto) k = keys[-1] - self.assertEqual(len(k), 3) - self.assertEqual(k[1:], ("axis", 0)) + self.assertEqual(len(k), 4) + self.assertEqual(k[1:], ("axis", int, 0)) def test_numpy_topk(self): f = topk(Input("X"), Input("K")) @@ -2416,6 +2416,7 @@ def compute_labels(X, centers, use_sqrt=False): (DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2), "use_sqrt", + bool, True, ) self.assertEqual(f.available_versions, [key]) diff --git a/onnx_array_api/_helpers.py b/onnx_array_api/_helpers.py index 6191c92..f9808ca 100644 --- a/onnx_array_api/_helpers.py +++ b/onnx_array_api/_helpers.py @@ -39,7 +39,7 @@ def np_dtype_to_tensor_dtype(dtype: Any): elif dtype is int: dt = TensorProto.INT64 elif dtype is float: - dt = TensorProto.FLOAT64 + dt = TensorProto.DOUBLE else: raise KeyError(f"Unable to guess type for dtype={dtype}.") return dt diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 9f50d3f..425418f 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -3,7 +3,6 @@ """ from typing import Any, Optional import numpy as np -from onnx import TensorProto from ..npx.npx_functions import ( all, abs, @@ -60,7 +59,7 @@ def asarray( def ones( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: if isinstance(shape, tuple): @@ -78,7 +77,7 @@ def ones( def empty( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: raise RuntimeError( @@ -89,7 +88,7 @@ def empty( def zeros( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: if isinstance(shape, tuple): diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index ab923b7..98e37f4 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -501,7 +501,7 @@ def matmul( @npxapi_inline def ones( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: """ @@ -510,7 +510,7 @@ def ones( if order != "C": raise RuntimeError(f"order={order!r} != 'C' not supported.") if dtype is None: - dtype = DType(TensorProto.FLOAT) + dtype = DType(TensorProto.DOUBLE) return var( shape, value=make_tensor(name="one", data_type=dtype.code, dims=[1], vals=[1]), @@ -711,7 +711,7 @@ def where( @npxapi_inline def zeros( shape: TensorType[ElemType.int64, "I", (None,)], - dtype: OptParType[DType] = DType(TensorProto.FLOAT), + dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: """ @@ -720,7 +720,7 @@ def zeros( if order != "C": raise RuntimeError(f"order={order!r} != 'C' not supported.") if dtype is None: - dtype = DType(TensorProto.FLOAT) + dtype = DType(TensorProto.DOUBLE) return var( shape, value=make_tensor(name="zero", data_type=dtype.code, dims=[1], vals=[0]), diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index bfb87fe..c222f01 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -132,6 +132,7 @@ def make_key(*values, **kwargs): if isinstance(v, (Var, EagerTensor, JitTensor)): res.append(v.key) elif isinstance(v, (int, float, bool, DType)): + res.append(type(v)) res.append(v) elif isinstance(v, slice): res.append(("slice", v.start, v.stop, v.step)) @@ -170,7 +171,8 @@ def make_key(*values, **kwargs): newv.append(t) res.append(tuple(newv)) elif v is None and k in {"dtype"}: - continue + res.append(k) + res.append(v) else: raise TypeError( f"Type {type(v)} is not yet supported, " From 3f14ff10088d99e13f0700938d2e259a3717c4bd Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 11:46:13 +0200 Subject: [PATCH 4/5] fix unit tests --- _unittests/ut_array_api/test_array_apis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_unittests/ut_array_api/test_array_apis.py b/_unittests/ut_array_api/test_array_apis.py index c72700c..9a8dd7c 100644 --- a/_unittests/ut_array_api/test_array_apis.py +++ b/_unittests/ut_array_api/test_array_apis.py @@ -13,7 +13,7 @@ class TestArraysApis(ExtTestCase): def test_zeros_numpy_1(self): c = xpn.zeros(1) d = c.numpy() - self.assertEqualArray(np.array([0], dtype=np.float32), d) + self.assertEqualArray(np.array([0], dtype=np.float64), d) def test_zeros_ort_1(self): c = xpo.zeros(1) From b41c59564557b06b30d71a2db7e3cf928702ed96 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 18 Jun 2023 11:56:17 +0200 Subject: [PATCH 5/5] ci --- _unittests/onnx-numpy-skips.txt | 1 - _unittests/test_array_api.sh | 2 +- azure-pipelines.yml | 10 +++++----- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index 3beafc6..62de43f 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -5,7 +5,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like array_api_tests/test_creation_functions.py::test_eye -# array_api_tests/test_creation_functions.py::test_full array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index 1de8dfb..9464ee6 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_ones || exit 1 +pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 \ No newline at end of file diff --git a/azure-pipelines.yml b/azure-pipelines.yml index ca24462..c449f2e 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -48,7 +48,7 @@ jobs: vmImage: 'ubuntu-latest' strategy: matrix: - Python310-Linux: + Python311-Linux: python.version: '3.11' maxParallel: 3 @@ -96,7 +96,7 @@ jobs: strategy: matrix: Python310-Linux: - python.version: '3.11' + python.version: '3.10' maxParallel: 3 steps: @@ -149,7 +149,7 @@ jobs: vmImage: 'ubuntu-latest' strategy: matrix: - Python310-Linux: + Python311-Linux: python.version: '3.11' maxParallel: 3 @@ -202,7 +202,7 @@ jobs: vmImage: 'windows-latest' strategy: matrix: - Python310-Windows: + Python311-Windows: python.version: '3.11' maxParallel: 3 @@ -235,7 +235,7 @@ jobs: vmImage: 'macOS-latest' strategy: matrix: - Python310-Mac: + Python311-Mac: python.version: '3.11' maxParallel: 3 pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy