From 2749bf1a4f3d510f550edf55d8cacca121f1bf9c Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 11 Jun 2023 13:11:39 +0200 Subject: [PATCH 1/9] Extends Array API to EagerOrt --- _unittests/ut_ort/test_ort_tensor.py | 48 ++++++++++++++++++++++++-- azure-pipelines.yml | 5 +++ onnx_array_api/array_api/__init__.py | 4 +++ onnx_array_api/array_api/onnx_numpy.py | 10 ++++++ onnx_array_api/array_api/onnx_ort.py | 25 +++++++++++++- onnx_array_api/npx/npx_jit_eager.py | 12 +++++++ onnx_array_api/npx/npx_tensors.py | 2 +- onnx_array_api/npx/npx_types.py | 11 ++++-- onnx_array_api/npx/npx_var.py | 9 +++++ onnx_array_api/ort/ort_tensors.py | 2 +- 10 files changed, 119 insertions(+), 9 deletions(-) diff --git a/_unittests/ut_ort/test_ort_tensor.py b/_unittests/ut_ort/test_ort_tensor.py index 57340d5..b673557 100644 --- a/_unittests/ut_ort/test_ort_tensor.py +++ b/_unittests/ut_ort/test_ort_tensor.py @@ -1,18 +1,17 @@ import unittest from contextlib import redirect_stdout from io import StringIO - import numpy as np from onnx.defs import onnx_opset_version from onnx.reference import ReferenceEvaluator from onnxruntime import InferenceSession - from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.npx import eager_onnx, jit_onnx from onnx_array_api.npx.npx_functions import absolute as absolute_inline from onnx_array_api.npx.npx_functions import cdist as cdist_inline from onnx_array_api.npx.npx_functions_test import absolute -from onnx_array_api.npx.npx_types import Float32, Float64 +from onnx_array_api.npx.npx_functions import copy as copy_inline +from onnx_array_api.npx.npx_types import Float32, Float64, DType from onnx_array_api.npx.npx_var import Input from onnx_array_api.ort.ort_tensors import EagerOrtTensor, JitOrtTensor, OrtTensor @@ -193,6 +192,49 @@ def impl(xa, xb): if len(pieces) > 2: raise AssertionError(f"Function is not using argument:\n{onx}") + def test_astype(self): + f = absolute_inline(copy_inline(Input("A")).astype(np.float32)) + onx = f.to_onnx(constraints={"A": Float64[None]}) + x = np.array([[-5, 6]], dtype=np.float64) + z = np.abs(x.astype(np.float32)) + ref = InferenceSession( + onx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + got = ref.run(None, {"A": x}) + self.assertEqualArray(z, got[0]) + + def test_astype0(self): + f = absolute_inline(copy_inline(Input("A")).astype(np.float32)) + onx = f.to_onnx(constraints={"A": Float64[None]}) + x = np.array(-5, dtype=np.float64) + z = np.abs(x.astype(np.float32)) + ref = InferenceSession( + onx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + got = ref.run(None, {"A": x}) + self.assertEqualArray(z, got[0]) + + def test_eager_ort_cast(self): + def impl(A): + return A.astype(DType("FLOAT")) + + e = eager_onnx(impl) + self.assertEqual(len(e.versions), 0) + + # Float64 + x = np.array([0, 1, -2], dtype=np.float64) + z = x.astype(np.float32) + res = e(x) + self.assertEqualArray(z, res) + self.assertEqual(res.dtype, np.float32) + + # again + x = np.array(1, dtype=np.float64) + z = x.astype(np.float32) + res = e(x) + self.assertEqualArray(z, res) + self.assertEqual(res.dtype, np.float32) + if __name__ == "__main__": # TestNpx().test_eager_numpy() diff --git a/azure-pipelines.yml b/azure-pipelines.yml index defe983..c2ed20f 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -131,6 +131,11 @@ jobs: cd array-api-tests python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros displayName: "test_creation_functions.py::test_zeros" + - script: | + export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort + cd array-api-tests + python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros + displayName: "test_creation_functions.py::test_zeros" #- script: | # export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy # cd array-api-tests diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index e13b184..217d575 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -3,6 +3,10 @@ def _finalize_array_api(module): + """ + Adds common attributes to Array API defined in this modules + such as types. + """ module.float16 = DType(TensorProto.FLOAT16) module.float32 = DType(TensorProto.FLOAT) module.float64 = DType(TensorProto.DOUBLE) diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 79b339d..76601d1 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -58,10 +58,20 @@ def zeros( return generic_zeros( EagerNumpyTensor(np.array(shape, dtype=np.int64)), dtype=dtype, order=order ) + if isinstance(shape, int): + return generic_zeros( + EagerNumpyTensor(np.array([shape], dtype=np.int64)), + dtype=dtype, + order=order, + ) return generic_zeros(shape, dtype=dtype, order=order) def _finalize(): + """ + Adds common attributes to Array API defined in this modules + such as types. + """ from . import onnx_numpy _finalize_array_api(onnx_numpy) diff --git a/onnx_array_api/array_api/onnx_ort.py b/onnx_array_api/array_api/onnx_ort.py index 505efdf..557f73c 100644 --- a/onnx_array_api/array_api/onnx_ort.py +++ b/onnx_array_api/array_api/onnx_ort.py @@ -2,8 +2,9 @@ Array API valid for an :class:`EagerOrtTensor`. """ from typing import Optional, Any +import numpy as np +from onnx import TensorProto from ..ort.ort_tensors import EagerOrtTensor -from ..npx.npx_types import DType from ..npx.npx_functions import ( all, abs, @@ -14,6 +15,8 @@ reshape, take, ) +from ..npx.npx_types import DType, ElemType, TensorType, OptParType +from ..npx.npx_functions import zeros as generic_zeros from ._onnx_common import template_asarray from . import _finalize_array_api @@ -45,7 +48,27 @@ def asarray( ) +def zeros( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + if isinstance(shape, tuple): + return generic_zeros( + EagerOrtTensor(np.array(shape, dtype=np.int64)), dtype=dtype, order=order + ) + if isinstance(shape, int): + return generic_zeros( + EagerOrtTensor(np.array([shape], dtype=np.int64)), dtype=dtype, order=order + ) + return generic_zeros(shape, dtype=dtype, order=order) + + def _finalize(): + """ + Adds common attributes to Array API defined in this modules + such as types. + """ from . import onnx_ort _finalize_array_api(onnx_ort) diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 85b52d4..35ff9af 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -267,6 +267,18 @@ def to_jit(self, *values, **kwargs): target_opsets=self.target_opsets, ir_version=self.ir_version, ) + if len(values) > 0 and len(values[0].shape) == 0: + inps = onx.graph.input[0] + shape = [] + for d in inps.type.tensor_type.shape.dim: + v = d.dim_value if d.dim_value > 0 else d.dim_param + shape.append(v) + if len(shape) != 0: + raise RuntimeError( + f"Shape mismatch, values[0]={values[0]} " + f"and inputs={onx.graph.input}." + ) + exe = self.tensor_class.create_function(names, onx) self.info("-", "to_jit") return onx, exe diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py index e1e4b21..1e11408 100644 --- a/onnx_array_api/npx/npx_tensors.py +++ b/onnx_array_api/npx/npx_tensors.py @@ -77,9 +77,9 @@ def _getitem_impl_var(obj, index, method_name=None): def _astype_impl( x: TensorType[ElemType.allowed, "T1"], dtype: ParType[DType], method_name=None ) -> TensorType[ElemType.allowed, "T2"]: - # avoids circular imports. if dtype is None: raise ValueError("dtype cannot be None.") + # avoids circular imports. from .npx_var import Var if not isinstance(x, Var): diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index aa335bd..b9b05f2 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -18,13 +18,18 @@ class DType(WrapperType): Type of the element type returned by tensors following the :epkg:`Array API`. - :param code: element type based on onnx definition + :param code: element type based on onnx definition, + if str, it looks into class :class:`onnxTensorProto` + to retrieve the code """ __slots__ = ["code_"] - def __init__(self, code: int): - self.code_ = code + def __init__(self, code: Union[int, str]): + if isinstance(code, str): + self.code_ = getattr(TensorProto, code) + else: + self.code_ = code def __repr__(self) -> str: "usual" diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index ae5b732..01be4e7 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -199,6 +199,15 @@ class Var(BaseArrayApi): :param onnx_input_type_: names given to the variables """ + def __array_namespace__(self, api_version: Optional[str] = None): + """ + Raises an exception if called. + """ + raise RuntimeError( + f"This function should never be called for class {type(self)}. " + f"It should be called for an eager tensor." + ) + @staticmethod def get_cst_var(): from .npx_core_api import cst, var diff --git a/onnx_array_api/ort/ort_tensors.py b/onnx_array_api/ort/ort_tensors.py index ead834d..37e8386 100644 --- a/onnx_array_api/ort/ort_tensors.py +++ b/onnx_array_api/ort/ort_tensors.py @@ -148,7 +148,7 @@ def ndim(self): @property def shape(self) -> Tuple[int, ...]: "Returns the shape of the tensor." - return self._tensor.shape() + return tuple(self._tensor.shape()) @property def dtype(self) -> DType: From 0e73dda56c7b650e17252483f671dddadc1fb66e Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 12 Jun 2023 11:50:25 +0200 Subject: [PATCH 2/9] fix empty shape --- .gitignore | 1 + onnx_array_api/npx/npx_graph_builder.py | 2 +- onnx_array_api/npx/npx_numpy_tensors.py | 9 ++++----- onnx_array_api/npx/npx_var.py | 2 +- onnx_array_api/ort/ort_tensors.py | 9 ++++----- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 136737c..f4d6253 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ _cache/* dist/* build/* .eggs/* +.hypothesis/* *egg-info/* _doc/auto_examples/* _doc/examples/_cache/* diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index ec91b91..d41b91c 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -273,7 +273,7 @@ def _io( self, index: int, name: str, tensor_type: Optional[type], is_input: bool ) -> ValueInfoProto: """ - Converts an input or outut into :class:`onnx.ValueInfoProto`. + Converts an input or output into :class:`onnx.ValueInfoProto`. :param index: index of the input or output to add :param name: input or output name diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index e1a0c10..404a8b1 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -107,13 +107,12 @@ def dims(self): """ Returns the dimensions of the tensor. First dimension is the batch dimension if the tensor - has more than one dimension. + has more than one dimension. It is always left undefined. """ - if len(self._tensor.shape) == 0: - return (0,) - if len(self._tensor.shape) == 1: + if len(self._tensor.shape) <= 1: + # a scalar (len==0) or a 1D tensor return self._tensor.shape - return (None,) + self._tensor.shape[1:] + return (None, *tuple(self.shape[1:])) @property def ndim(self): diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index 01be4e7..d6b1ac1 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -985,7 +985,7 @@ def __getitem__(self, index: Any) -> "Var": cst, var = Var.get_cst_var() if self.n_var_outputs != 1: - # Multioutut + # Multioutput if not isinstance(index, int): raise TypeError( f"Only indices are allowed when selecting an output, " diff --git a/onnx_array_api/ort/ort_tensors.py b/onnx_array_api/ort/ort_tensors.py index 37e8386..037a64c 100644 --- a/onnx_array_api/ort/ort_tensors.py +++ b/onnx_array_api/ort/ort_tensors.py @@ -175,12 +175,11 @@ def dims(self): """ Returns the dimensions of the tensor. First dimension is the batch dimension if the tensor - has more than one dimension. + has more than one dimension. It is always left undefined. """ - if len(self.shape) == 0: - return (0,) - if len(self.shape) == 1: - return tuple(self.shape) + if len(self._tensor.shape) <= 1: + # a scalar (len==0) or a 1D tensor + return self._tensor.shape return (None, *tuple(self.shape[1:])) @property From 62b82207697566ac459bd29e7441ef227ad52bde Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 12 Jun 2023 12:00:29 +0200 Subject: [PATCH 3/9] fix shape --- _unittests/ut_array_api/test_array_apis.py | 23 ++++++++++++++++++++++ _unittests/ut_array_api/test_onnx_numpy.py | 4 ++-- _unittests/ut_array_api/test_onnx_ort.py | 20 +++++++++++++++++++ onnx_array_api/ort/ort_tensors.py | 4 ++-- 4 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 _unittests/ut_array_api/test_array_apis.py create mode 100644 _unittests/ut_array_api/test_onnx_ort.py diff --git a/_unittests/ut_array_api/test_array_apis.py b/_unittests/ut_array_api/test_array_apis.py new file mode 100644 index 0000000..a85cca7 --- /dev/null +++ b/_unittests/ut_array_api/test_array_apis.py @@ -0,0 +1,23 @@ +import unittest +import numpy as np +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.array_api import onnx_numpy as xpn +from onnx_array_api.array_api import onnx_ort as xpo +# from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor +# from onnx_array_api.ort.ort_tensors import EagerOrtTensor + + +class TestArraysApis(ExtTestCase): + def test_zeros_numpy_1(self): + c = xpn.zeros(1) + d = c.numpy() + self.assertEqualArray(np.array([0], dtype=np.float32), d) + + def test_zeros_ort_1(self): + c = xpo.zeros(1) + d = c.numpy() + self.assertEqualArray(np.array([0], dtype=np.float32), d) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 30e2ca2..bdf870c 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -2,12 +2,12 @@ import numpy as np from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.array_api import onnx_numpy as xp -from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor +from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor as EagerTensor class TestOnnxNumpy(ExtTestCase): def test_abs(self): - c = EagerNumpyTensor(np.array([4, 5], dtype=np.int64)) + c = EagerTensor(np.array([4, 5], dtype=np.int64)) mat = xp.zeros(c, dtype=xp.int64) matnp = mat.numpy() self.assertEqual(matnp.shape, (4, 5)) diff --git a/_unittests/ut_array_api/test_onnx_ort.py b/_unittests/ut_array_api/test_onnx_ort.py new file mode 100644 index 0000000..a10b0d0 --- /dev/null +++ b/_unittests/ut_array_api/test_onnx_ort.py @@ -0,0 +1,20 @@ +import unittest +import numpy as np +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.array_api import onnx_ort as xp +from onnx_array_api.ort.ort_tensors import EagerOrtTensor as EagerTensor + + +class TestOnnxOrt(ExtTestCase): + def test_abs(self): + c = EagerTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.zeros(c, dtype=xp.int64) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + a = xp.absolute(mat) + self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnx_array_api/ort/ort_tensors.py b/onnx_array_api/ort/ort_tensors.py index 037a64c..db9d4d5 100644 --- a/onnx_array_api/ort/ort_tensors.py +++ b/onnx_array_api/ort/ort_tensors.py @@ -177,9 +177,9 @@ def dims(self): First dimension is the batch dimension if the tensor has more than one dimension. It is always left undefined. """ - if len(self._tensor.shape) <= 1: + if len(self._tensor.shape()) <= 1: # a scalar (len==0) or a 1D tensor - return self._tensor.shape + return tuple(self._tensor.shape()) return (None, *tuple(self.shape[1:])) @property From 2d26568ce0003d1d0eb738e265346d1f3836a438 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 12 Jun 2023 12:08:37 +0200 Subject: [PATCH 4/9] fix azure --- azure-pipelines.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index c2ed20f..1b4b9c8 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -110,6 +110,8 @@ jobs: displayName: 'Install tools' - script: pip install -r requirements.txt displayName: 'Install Requirements' + - script: pip install onnxruntime + displayName: 'Install onnxruntime' - script: python setup.py install displayName: 'Install onnx_array_api' - script: | @@ -130,12 +132,12 @@ jobs: export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy cd array-api-tests python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros - displayName: "test_creation_functions.py::test_zeros" + displayName: "numpy test_zeros" - script: | export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort cd array-api-tests python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros - displayName: "test_creation_functions.py::test_zeros" + displayName: "ort test_zeros" #- script: | # export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy # cd array-api-tests From 857e3b0e1a903e9d56736447d2496c28f9498a17 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 12 Jun 2023 14:50:51 +0200 Subject: [PATCH 5/9] refactoring --- _unittests/onnx-numpy-skips.txt | 13 ++++ _unittests/onnx-ort-skips.txt | 15 ++++ _unittests/ut_array_api/test_array_apis.py | 89 ++++++++++++++++++++++ azure-pipelines.yml | 10 +-- onnx_array_api/_helpers.py | 45 +++++++++++ onnx_array_api/array_api/__init__.py | 32 ++++++++ onnx_array_api/array_api/_onnx_common.py | 2 + onnx_array_api/array_api/onnx_numpy.py | 24 ++++++ onnx_array_api/array_api/onnx_ort.py | 4 + onnx_array_api/npx/npx_functions.py | 43 +++++++---- onnx_array_api/npx/npx_numpy_tensors.py | 4 +- onnx_array_api/npx/npx_tensors.py | 8 +- onnx_array_api/npx/npx_types.py | 9 ++- onnx_array_api/npx/npx_var.py | 39 +--------- 14 files changed, 271 insertions(+), 66 deletions(-) create mode 100644 _unittests/onnx-numpy-skips.txt create mode 100644 _unittests/onnx-ort-skips.txt create mode 100644 onnx_array_api/_helpers.py diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt new file mode 100644 index 0000000..3cdbb31 --- /dev/null +++ b/_unittests/onnx-numpy-skips.txt @@ -0,0 +1,13 @@ +# API failures +array_api_tests/test_creation_functions.py::test_arange +array_api_tests/test_creation_functions.py::test_asarray_scalars +array_api_tests/test_creation_functions.py::test_asarray_arrays +array_api_tests/test_creation_functions.py::test_empty +array_api_tests/test_creation_functions.py::test_empty_like +array_api_tests/test_creation_functions.py::test_eye +array_api_tests/test_creation_functions.py::test_full +array_api_tests/test_creation_functions.py::test_full_like +array_api_tests/test_creation_functions.py::test_linspace +array_api_tests/test_creation_functions.py::test_meshgrid +array_api_tests/test_creation_functions.py::test_ones_like +array_api_tests/test_creation_functions.py::test_zeros_like diff --git a/_unittests/onnx-ort-skips.txt b/_unittests/onnx-ort-skips.txt new file mode 100644 index 0000000..557d14b --- /dev/null +++ b/_unittests/onnx-ort-skips.txt @@ -0,0 +1,15 @@ +# Not implementated by onnxruntime +array_api_tests/test_creation_functions.py::test_arange +array_api_tests/test_creation_functions.py::test_asarray_scalars +array_api_tests/test_creation_functions.py::test_asarray_arrays +array_api_tests/test_creation_functions.py::test_empty +array_api_tests/test_creation_functions.py::test_empty_like +array_api_tests/test_creation_functions.py::test_eye +array_api_tests/test_creation_functions.py::test_full +array_api_tests/test_creation_functions.py::test_full_like +array_api_tests/test_creation_functions.py::test_linspace +array_api_tests/test_creation_functions.py::test_meshgrid +array_api_tests/test_creation_functions.py::test_ones +array_api_tests/test_creation_functions.py::test_ones_like +array_api_tests/test_creation_functions.py::test_zeros +array_api_tests/test_creation_functions.py::test_zeros_like diff --git a/_unittests/ut_array_api/test_array_apis.py b/_unittests/ut_array_api/test_array_apis.py index a85cca7..c72700c 100644 --- a/_unittests/ut_array_api/test_array_apis.py +++ b/_unittests/ut_array_api/test_array_apis.py @@ -1,8 +1,10 @@ import unittest +from inspect import isfunction, ismethod import numpy as np from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.array_api import onnx_numpy as xpn from onnx_array_api.array_api import onnx_ort as xpo + # from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor # from onnx_array_api.ort.ort_tensors import EagerOrtTensor @@ -18,6 +20,93 @@ def test_zeros_ort_1(self): d = c.numpy() self.assertEqualArray(np.array([0], dtype=np.float32), d) + def test_ffinfo(self): + dt = np.float32 + fi1 = np.finfo(dt) + fi2 = xpn.finfo(dt) + fi3 = xpo.finfo(dt) + dt1 = fi1.dtype + dt2 = fi2.dtype + dt3 = fi3.dtype + self.assertEqual(dt2, dt3) + self.assertNotEqual(dt1.__class__, dt2.__class__) + mi1 = fi1.min + mi2 = fi2.min + self.assertEqual(mi1, mi2) + mi1 = fi1.smallest_normal + mi2 = fi2.smallest_normal + self.assertEqual(mi1, mi2) + for n in dir(fi1): + if n.startswith("__"): + continue + if n in {"machar"}: + continue + v1 = getattr(fi1, n) + with self.subTest(att=n): + v2 = getattr(fi2, n) + v3 = getattr(fi3, n) + if isfunction(v1) or ismethod(v1): + try: + v1 = v1() + except TypeError: + continue + v2 = v2() + v3 = v3() + if v1 != v2: + raise AssertionError( + f"12: info disagree on name {n!r}: {v1} != {v2}, " + f"type(v1)={type(v1)}, type(v2)={type(v2)}, " + f"ismethod={ismethod(v1)}." + ) + if v2 != v3: + raise AssertionError( + f"23: info disagree on name {n!r}: {v2} != {v3}, " + f"type(v1)={type(v1)}, type(v2)={type(v2)}, " + f"ismethod={ismethod(v1)}." + ) + + def test_iiinfo(self): + dt = np.int64 + fi1 = np.iinfo(dt) + fi2 = xpn.iinfo(dt) + fi3 = xpo.iinfo(dt) + dt1 = fi1.dtype + dt2 = fi2.dtype + dt3 = fi3.dtype + self.assertEqual(dt2, dt3) + self.assertNotEqual(dt1.__class__, dt2.__class__) + mi1 = fi1.min + mi2 = fi2.min + self.assertEqual(mi1, mi2) + for n in dir(fi1): + if n.startswith("__"): + continue + if n in {"machar"}: + continue + v1 = getattr(fi1, n) + with self.subTest(att=n): + v2 = getattr(fi2, n) + v3 = getattr(fi3, n) + if isfunction(v1) or ismethod(v1): + try: + v1 = v1() + except TypeError: + continue + v2 = v2() + v3 = v3() + if v1 != v2: + raise AssertionError( + f"12: info disagree on name {n!r}: {v1} != {v2}, " + f"type(v1)={type(v1)}, type(v2)={type(v2)}, " + f"ismethod={ismethod(v1)}." + ) + if v2 != v3: + raise AssertionError( + f"23: info disagree on name {n!r}: {v2} != {v3}, " + f"type(v1)={type(v1)}, type(v2)={type(v2)}, " + f"ismethod={ismethod(v1)}." + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 1b4b9c8..e55fcfe 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -131,13 +131,13 @@ jobs: - script: | export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy cd array-api-tests - python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros - displayName: "numpy test_zeros" + python -m pytest -x array_api_tests/test_creation_functions.py --skips_file=../_unittests/onnx-numpy-skips.txt -v + displayName: "numpy test_creation_functions.py" - script: | - export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort + export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort --skips_file=../_unittests/onnx-numpy-skips.txt -v cd array-api-tests - python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros - displayName: "ort test_zeros" + python -m pytest -x array_api_tests/test_creation_functions.py + displayName: "ort test_creation_functions.py" #- script: | # export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy # cd array-api-tests diff --git a/onnx_array_api/_helpers.py b/onnx_array_api/_helpers.py new file mode 100644 index 0000000..6191c92 --- /dev/null +++ b/onnx_array_api/_helpers.py @@ -0,0 +1,45 @@ +import numpy as np +from typing import Any +from onnx import helper, TensorProto + + +def np_dtype_to_tensor_dtype(dtype: Any): + """ + Improves :func:`onnx.helper.np_dtype_to_tensor_dtype`. + """ + try: + dt = helper.np_dtype_to_tensor_dtype(dtype) + except KeyError: + if dtype == np.float32: + dt = TensorProto.FLOAT + elif dtype == np.float64: + dt = TensorProto.DOUBLE + elif dtype == np.int64: + dt = TensorProto.INT64 + elif dtype == np.int32: + dt = TensorProto.INT32 + elif dtype == np.int16: + dt = TensorProto.INT16 + elif dtype == np.int8: + dt = TensorProto.INT8 + elif dtype == np.uint64: + dt = TensorProto.UINT64 + elif dtype == np.uint32: + dt = TensorProto.UINT32 + elif dtype == np.uint16: + dt = TensorProto.UINT16 + elif dtype == np.uint8: + dt = TensorProto.UINT8 + elif dtype == np.float16: + dt = TensorProto.FLOAT16 + elif dtype in (bool, np.bool_): + dt = TensorProto.BOOL + elif dtype in (str, np.str_): + dt = TensorProto.STRING + elif dtype is int: + dt = TensorProto.INT64 + elif dtype is float: + dt = TensorProto.FLOAT64 + else: + raise KeyError(f"Unable to guess type for dtype={dtype}.") + return dt diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index 217d575..cc64b8e 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -1,7 +1,37 @@ +import numpy as np from onnx import TensorProto +from .._helpers import np_dtype_to_tensor_dtype from ..npx.npx_types import DType +def _finfo(dtype): + """ + Similar to :class:`numpy.finfo`. + """ + dt = dtype.np_dtype if isinstance(dtype, DType) else dtype + res = np.finfo(dt) + d = res.__dict__.copy() + d["dtype"] = DType(np_dtype_to_tensor_dtype(dt)) + nres = type("finfo", (res.__class__,), d) + setattr(nres, "smallest_normal", res.smallest_normal) + setattr(nres, "tiny", res.tiny) + return nres + + +def _iinfo(dtype): + """ + Similar to :class:`numpy.finfo`. + """ + dt = dtype.np_dtype if isinstance(dtype, DType) else dtype + res = np.iinfo(dt) + d = res.__dict__.copy() + d["dtype"] = DType(np_dtype_to_tensor_dtype(dt)) + nres = type("finfo", (res.__class__,), d) + setattr(nres, "min", res.min) + setattr(nres, "max", res.max) + return nres + + def _finalize_array_api(module): """ Adds common attributes to Array API defined in this modules @@ -21,3 +51,5 @@ def _finalize_array_api(module): module.bfloat16 = DType(TensorProto.BFLOAT16) setattr(module, "bool", DType(TensorProto.BOOL)) setattr(module, "str", DType(TensorProto.STRING)) + setattr(module, "finfo", _finfo) + setattr(module, "iinfo", _iinfo) diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 8d136c4..25ace54 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -41,6 +41,8 @@ def template_asarray( v = TEagerTensor(np.array(a, dtype=np.bool_)) elif isinstance(a, str): v = TEagerTensor(np.array(a, dtype=np.str_)) + elif isinstance(a, list): + v = TEagerTensor(np.array(a)) else: raise RuntimeError(f"Unexpected type {type(a)} for the first input.") if dtype is not None: diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py index 76601d1..c20fb15 100644 --- a/onnx_array_api/array_api/onnx_numpy.py +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -11,9 +11,12 @@ astype, equal, isdtype, + isfinite, + isnan, reshape, take, ) +from ..npx.npx_functions import ones as generic_ones from ..npx.npx_functions import zeros as generic_zeros from ..npx.npx_numpy_tensors import EagerNumpyTensor from ..npx.npx_types import DType, ElemType, TensorType, OptParType @@ -28,6 +31,9 @@ "astype", "equal", "isdtype", + "isfinite", + "isnan", + "ones", "reshape", "take", "zeros", @@ -49,6 +55,24 @@ def asarray( ) +def ones( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + if isinstance(shape, tuple): + return generic_ones( + EagerNumpyTensor(np.array(shape, dtype=np.int64)), dtype=dtype, order=order + ) + if isinstance(shape, int): + return generic_ones( + EagerNumpyTensor(np.array([shape], dtype=np.int64)), + dtype=dtype, + order=order, + ) + return generic_ones(shape, dtype=dtype, order=order) + + def zeros( shape: TensorType[ElemType.int64, "I", (None,)], dtype: OptParType[DType] = DType(TensorProto.FLOAT), diff --git a/onnx_array_api/array_api/onnx_ort.py b/onnx_array_api/array_api/onnx_ort.py index 557f73c..56f6444 100644 --- a/onnx_array_api/array_api/onnx_ort.py +++ b/onnx_array_api/array_api/onnx_ort.py @@ -12,6 +12,8 @@ astype, equal, isdtype, + isnan, + isfinite, reshape, take, ) @@ -28,6 +30,8 @@ "astype", "equal", "isdtype", + "isfinite", + "isnan", "reshape", "take", ] diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index b55cf4d..29a4481 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -1,11 +1,10 @@ from typing import Optional, Tuple, Union - import array_api_compat.numpy as np_array_api import numpy as np from onnx import FunctionProto, ModelProto, NodeProto, TensorProto -from onnx.helper import make_tensor, np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype +from onnx.helper import make_tensor, tensor_dtype_to_np_dtype from onnx.numpy_helper import from_array - +from .._helpers import np_dtype_to_tensor_dtype from .npx_constants import FUNCTION_DOMAIN from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var from .npx_types import ( @@ -203,15 +202,7 @@ def astype( raise TypeError( f"dtype is an attribute, it cannot be a Variable of type {type(dtype)}." ) - try: - to = np_dtype_to_tensor_dtype(dtype) - except KeyError: - if dtype is int: - to = TensorProto.INT64 - elif dtype is float: - to = TensorProto.float64 - else: - raise ValueError(f"Unable to guess tensor type from {dtype}.") + to = np_dtype_to_tensor_dtype(dtype) return var(a, op="Cast", to=to) @@ -351,7 +342,7 @@ def einsum( def equal( x: TensorType[ElemType.allowed, "T"], y: TensorType[ElemType.allowed, "T"] ) -> TensorType[ElemType.bool_, "T1"]: - "See :func:`numpy.isnan`." + "See :func:`numpy.equal`." return var(x, y, op="Equal") @@ -437,6 +428,12 @@ def isdtype( return np_array_api.isdtype(dtype, kind) +@npxapi_inline +def isfinite(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T1"]: + "See :func:`numpy.isfinite`." + return var(x, op="IsInf") + + @npxapi_inline def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T1"]: "See :func:`numpy.isnan`." @@ -464,6 +461,26 @@ def matmul( return var(a, b, op="MatMul") +@npxapi_inline +def ones( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + """ + Implements :func:`numpy.zeros`. + """ + if order != "C": + raise RuntimeError(f"order={order!r} != 'C' not supported.") + if dtype is None: + dtype = DType(TensorProto.FLOAT) + return var( + shape, + value=make_tensor(name="one", data_type=dtype.code, dims=[1], vals=[1]), + op="ConstantOfShape", + ) + + @npxapi_inline def pad( x: TensorType[ElemType.numerics, "T"], diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index 404a8b1..15f9588 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -1,10 +1,8 @@ from typing import Any, Callable, List, Optional, Tuple - import numpy as np from onnx import ModelProto -from onnx.helper import np_dtype_to_tensor_dtype from onnx.reference import ReferenceEvaluator - +from .._helpers import np_dtype_to_tensor_dtype from .npx_numpy_tensors_ops import ConstantOfShape from .npx_tensors import EagerTensor, JitTensor from .npx_types import DType, TensorType diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py index 1e11408..b0e92c2 100644 --- a/onnx_array_api/npx/npx_tensors.py +++ b/onnx_array_api/npx/npx_tensors.py @@ -1,8 +1,6 @@ from typing import Any, Union - import numpy as np -from onnx.helper import np_dtype_to_tensor_dtype - +from .._helpers import np_dtype_to_tensor_dtype from .npx_types import DType, ElemType, ParType, TensorType from .npx_array_api import BaseArrayApi, ArrayApiError @@ -178,10 +176,6 @@ def _generic_method_reduce(self, method_name, *args: Any, **kwargs: Any) -> Any: @staticmethod def _np_dtype_to_tensor_dtype(dtype): - if dtype == int: - dtype = np.dtype("int64") - elif dtype == float: - dtype = np.dtype("float64") return np_dtype_to_tensor_dtype(dtype) def _generic_method_astype( diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index b9b05f2..6063e64 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -2,7 +2,8 @@ import numpy as np from onnx import AttributeProto, TensorProto -from onnx.helper import np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype +from onnx.helper import tensor_dtype_to_np_dtype +from .._helpers import np_dtype_to_tensor_dtype class WrapperType: @@ -28,8 +29,10 @@ class DType(WrapperType): def __init__(self, code: Union[int, str]): if isinstance(code, str): self.code_ = getattr(TensorProto, code) - else: + elif isinstance(code, int): self.code_ = code + else: + raise TypeError(f"Unsupported type {type(code)}:{code!r}") def __repr__(self) -> str: "usual" @@ -60,6 +63,8 @@ def __eq__(self, dt: "DType") -> bool: return self.code_ == TensorProto.STRING if dt is bool: return self.code_ == TensorProto.BOOL + if isinstance(dt, list): + return False if dt in ElemType.numpy_map: dti = ElemType.numpy_map[dt] return self.code_ == dti.code_ diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index d6b1ac1..42b1b5a 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -1,9 +1,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union - import numpy as np -from onnx import FunctionProto, ModelProto, NodeProto, TensorProto -from onnx.helper import np_dtype_to_tensor_dtype - +from onnx import FunctionProto, ModelProto, NodeProto +from .._helpers import np_dtype_to_tensor_dtype from .npx_array_api import BaseArrayApi, ArrayApiError from .npx_constants import DEFAULT_OPSETS, ONNX_DOMAIN from .npx_types import DType, ElemType, OptParType, ParType, TensorType, TupleType @@ -831,38 +829,7 @@ def astype(self, dtype) -> "Var": if isinstance(dtype, Var): return var(self.self_var, dtype, op="CastLike") if not isinstance(dtype, int): - try: - dtype = np_dtype_to_tensor_dtype(dtype) - except KeyError: - if dtype == np.float32: - dtype = TensorProto.FLOAT - elif dtype == np.float64: - dtype = TensorProto.DOUBLE - elif dtype == np.int64: - dtype = TensorProto.INT64 - elif dtype == np.int32: - dtype = TensorProto.INT32 - elif dtype == np.int16: - dtype = TensorProto.INT16 - elif dtype == np.int8: - dtype = TensorProto.INT8 - elif dtype == np.uint64: - dtype = TensorProto.UINT64 - elif dtype == np.uint32: - dtype = TensorProto.UINT32 - elif dtype == np.uint16: - dtype = TensorProto.UINT16 - elif dtype == np.uint8: - dtype = TensorProto.UINT8 - elif dtype == np.float16: - dtype = TensorProto.FLOAT16 - elif dtype in (bool, np.bool_): - dtype = TensorProto.BOOL - elif dtype in (str, np.str_): - dtype = TensorProto.STRING - else: - raise RuntimeError(f"Unable to guess type for dtype={dtype}.") - + dtype = np_dtype_to_tensor_dtype(dtype) return var(self.self_var, op="Cast", to=dtype) @property From 6b949cbd3db75982cf0fa508ec709d3fe007bafe Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Mon, 12 Jun 2023 14:56:36 +0200 Subject: [PATCH 6/9] fix command line --- azure-pipelines.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index e55fcfe..761e3de 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -131,10 +131,10 @@ jobs: - script: | export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy cd array-api-tests - python -m pytest -x array_api_tests/test_creation_functions.py --skips_file=../_unittests/onnx-numpy-skips.txt -v + python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v displayName: "numpy test_creation_functions.py" - script: | - export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort --skips_file=../_unittests/onnx-numpy-skips.txt -v + export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort --skips-file=../_unittests/onnx-numpy-skips.txt -v cd array-api-tests python -m pytest -x array_api_tests/test_creation_functions.py displayName: "ort test_creation_functions.py" From 10d438eeb5f21a6de0baa2f63c1ed62661475749 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 13 Jun 2023 00:30:05 +0200 Subject: [PATCH 7/9] CI --- azure-pipelines.yml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 761e3de..1ff7ace 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -253,16 +253,8 @@ jobs: displayName: 'export' - script: gcc --version displayName: 'gcc version' - - script: brew install llvm - displayName: 'install llvm' - - script: brew install libomp - displayName: 'Install omp' - - script: brew install p7zip - displayName: 'Install p7zip' - script: python -m pip install --upgrade pip setuptools wheel displayName: 'Install tools' - - script: brew install pybind11 - displayName: 'Install pybind11' - script: pip install -r requirements.txt displayName: 'Install Requirements' - script: pip install -r requirements-dev.txt From 0d0e9275566de4b9774c42055701c7296621fe9a Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 13 Jun 2023 00:43:03 +0200 Subject: [PATCH 8/9] fix CI --- azure-pipelines.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 1ff7ace..018e915 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -134,9 +134,9 @@ jobs: python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v displayName: "numpy test_creation_functions.py" - script: | - export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort --skips-file=../_unittests/onnx-numpy-skips.txt -v + export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort cd array-api-tests - python -m pytest -x array_api_tests/test_creation_functions.py + python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v displayName: "ort test_creation_functions.py" #- script: | # export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy From 73585c518ce9d3a833648910869dc6f7b7323d0d Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Tue, 13 Jun 2023 00:54:46 +0200 Subject: [PATCH 9/9] fix CI --- azure-pipelines.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 018e915..b711ecf 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -136,7 +136,7 @@ jobs: - script: | export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_ort cd array-api-tests - python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-numpy-skips.txt -v + python -m pytest -x array_api_tests/test_creation_functions.py --skips-file=../_unittests/onnx-ort-skips.txt -v displayName: "ort test_creation_functions.py" #- script: | # export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy