diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 4e5aeb5..e807b02 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,4 +4,5 @@ Change Logs 0.2.0 +++++ -* :pr:`3`: fixes Array API with onnxruntime +* :pr:`17`: implements ArrayAPI +* :pr:`3`: fixes Array API with onnxruntime and scikit-learn diff --git a/_doc/api/array_api.rst b/_doc/api/array_api.rst new file mode 100644 index 0000000..f07716a --- /dev/null +++ b/_doc/api/array_api.rst @@ -0,0 +1,7 @@ +onnx_array_api.array_api +======================== + +.. toctree:: + + array_api_onnx_numpy + array_api_onnx_ort diff --git a/_doc/api/array_api_numpy.rst b/_doc/api/array_api_numpy.rst new file mode 100644 index 0000000..f57089a --- /dev/null +++ b/_doc/api/array_api_numpy.rst @@ -0,0 +1,5 @@ +onnx_array_api.array_api.onnx_numpy +============================================= + +.. automodule:: onnx_array_api.array_api.onnx_numpy + :members: diff --git a/_doc/api/array_api_ort.rst b/_doc/api/array_api_ort.rst new file mode 100644 index 0000000..cc21311 --- /dev/null +++ b/_doc/api/array_api_ort.rst @@ -0,0 +1,5 @@ +onnx_array_api.array_api.onnx_ort +================================= + +.. automodule:: onnx_array_api.array_api.onnx_ort + :members: diff --git a/_doc/api/index.rst b/_doc/api/index.rst index 7750a5b..75c0aa4 100644 --- a/_doc/api/index.rst +++ b/_doc/api/index.rst @@ -6,6 +6,7 @@ API .. toctree:: :maxdepth: 1 + array_api npx_functions npx_var npx_jit diff --git a/_doc/api/npx_annot.rst b/_doc/api/npx_annot.rst index d7e46e3..43de2d7 100644 --- a/_doc/api/npx_annot.rst +++ b/_doc/api/npx_annot.rst @@ -1,29 +1,54 @@ +============= npx.npx_types ============= +DType +===== + +.. autoclass:: onnx_array_api.npx.npx_types.DType + :members: + Annotations -+++++++++++ +=========== + +ElemType +++++++++ .. autoclass:: onnx_array_api.npx.npx_types.ElemType :members: +ParType ++++++++ + .. autoclass:: onnx_array_api.npx.npx_types.ParType :members: +OptParType +++++++++++ + .. autoclass:: onnx_array_api.npx.npx_types.OptParType :members: +TensorType +++++++++++ + .. autoclass:: onnx_array_api.npx.npx_types.TensorType :members: +SequenceType +++++++++++++ + .. autoclass:: onnx_array_api.npx.npx_types.SequenceType :members: +TupleType ++++++++++ + .. autoclass:: onnx_array_api.npx.npx_types.TupleType :members: Shortcuts -+++++++++ +========= .. autoclass:: onnx_array_api.npx.npx_types.Bool diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh new file mode 100644 index 0000000..b32ee41 --- /dev/null +++ b/_unittests/test_array_api.sh @@ -0,0 +1,2 @@ +export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy +pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros \ No newline at end of file diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py new file mode 100644 index 0000000..30e2ca2 --- /dev/null +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -0,0 +1,20 @@ +import unittest +import numpy as np +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.array_api import onnx_numpy as xp +from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor + + +class TestOnnxNumpy(ExtTestCase): + def test_abs(self): + c = EagerNumpyTensor(np.array([4, 5], dtype=np.int64)) + mat = xp.zeros(c, dtype=xp.int64) + matnp = mat.numpy() + self.assertEqual(matnp.shape, (4, 5)) + self.assertNotEmpty(matnp[0, 0]) + a = xp.absolute(mat) + self.assertEqualArray(np.absolute(mat.numpy()), a.numpy()) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_npx/test_npx.py b/_unittests/ut_npx/test_npx.py index f550896..c9ee35f 100644 --- a/_unittests/ut_npx/test_npx.py +++ b/_unittests/ut_npx/test_npx.py @@ -29,6 +29,7 @@ npxapi_inline, ) from onnx_array_api.npx.npx_functions import absolute as absolute_inline +from onnx_array_api.npx.npx_functions import all as all_inline from onnx_array_api.npx.npx_functions import arange as arange_inline from onnx_array_api.npx.npx_functions import arccos as arccos_inline from onnx_array_api.npx.npx_functions import arccosh as arccosh_inline @@ -50,6 +51,7 @@ from onnx_array_api.npx.npx_functions import det as det_inline from onnx_array_api.npx.npx_functions import dot as dot_inline from onnx_array_api.npx.npx_functions import einsum as einsum_inline +from onnx_array_api.npx.npx_functions import equal as equal_inline from onnx_array_api.npx.npx_functions import erf as erf_inline from onnx_array_api.npx.npx_functions import exp as exp_inline from onnx_array_api.npx.npx_functions import expand_dims as expand_dims_inline @@ -95,6 +97,7 @@ from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor from onnx_array_api.npx.npx_types import ( Bool, + DType, Float32, Float64, Int64, @@ -127,18 +130,25 @@ def test_tensor(self): self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) self.assertEmpty(dt.shape) self.assertEqual(dt.type_name(), "TensorType['float32']") + dt = TensorType["float32"] self.assertEqual(len(dt.dtypes), 1) self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) self.assertEqual(dt.type_name(), "TensorType['float32']") + dt = TensorType[np.float32] self.assertEqual(len(dt.dtypes), 1) self.assertEqual(dt.dtypes[0].dtype, ElemType.float32) self.assertEqual(dt.type_name(), "TensorType['float32']") self.assertEmpty(dt.shape) + dt = TensorType[np.str_] + self.assertEqual(len(dt.dtypes), 1) + self.assertEqual(dt.dtypes[0].dtype, ElemType.str_) + self.assertEqual(dt.type_name(), "TensorType[strings]") + self.assertEmpty(dt.shape) + self.assertRaise(lambda: TensorType[None], TypeError) - self.assertRaise(lambda: TensorType[np.str_], TypeError) self.assertRaise(lambda: TensorType[{np.float32, np.str_}], TypeError) def test_superset(self): @@ -1155,6 +1165,16 @@ def test_astype(self): got = ref.run(None, {"A": x}) self.assertEqualArray(z, got[0]) + def test_astype_dtype(self): + f = absolute_inline(copy_inline(Input("A")).astype(DType(7))) + self.assertIsInstance(f, Var) + onx = f.to_onnx(constraints={"A": Float64[None]}) + x = np.array([[-5.4, 6.6]], dtype=np.float64) + z = np.abs(x.astype(np.int64)) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {"A": x}) + self.assertEqualArray(z, got[0]) + def test_astype_int(self): f = absolute_inline(copy_inline(Input("A")).astype(1)) self.assertIsInstance(f, Var) @@ -1413,6 +1433,9 @@ def test_einsum(self): lambda x, y: np.einsum(equation, x, y), ) + def test_equal(self): + self.common_test_inline_bin(equal_inline, np.equal) + @unittest.skipIf(scipy is None, reason="scipy is not installed.") def test_erf(self): self.common_test_inline(erf_inline, scipy.special.erf) @@ -1460,7 +1483,17 @@ def test_hstack(self): def test_identity(self): f = identity_inline(2, dtype=np.float64) onx = f.to_onnx(constraints={(0, False): Float64[None]}) - z = np.identity(2) + self.assertIn('name: "dtype"', str(onx)) + z = np.identity(2).astype(np.float64) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {}) + self.assertEqualArray(z, got[0]) + + def test_identity_uint8(self): + f = identity_inline(2, dtype=np.uint8) + onx = f.to_onnx(constraints={(0, False): Float64[None]}) + self.assertIn('name: "dtype"', str(onx)) + z = np.identity(2).astype(np.uint8) ref = ReferenceEvaluator(onx) got = ref.run(None, {}) self.assertEqualArray(z, got[0]) @@ -2318,7 +2351,7 @@ def compute_labels(X, centers): self.assertEqual(f.n_versions, 1) self.assertEqual(len(f.available_versions), 1) self.assertEqual(f.available_versions, [((np.float64, 2), (np.float64, 2))]) - key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2)) + key = ((DType(TensorProto.DOUBLE), 2), (DType(TensorProto.DOUBLE), 2)) onx = f.get_onnx(key) self.assertIsInstance(onx, ModelProto) self.assertRaise(lambda: f.get_onnx(2), ValueError) @@ -2379,7 +2412,12 @@ def compute_labels(X, centers, use_sqrt=False): self.assertEqualArray(got[1], dist) self.assertEqual(f.n_versions, 1) self.assertEqual(len(f.available_versions), 1) - key = ((np.dtype("float64"), 2), (np.dtype("float64"), 2), "use_sqrt", True) + key = ( + (DType(TensorProto.DOUBLE), 2), + (DType(TensorProto.DOUBLE), 2), + "use_sqrt", + True, + ) self.assertEqual(f.available_versions, [key]) onx = f.get_onnx(key) self.assertIsInstance(onx, ModelProto) @@ -2452,7 +2490,52 @@ def test_take(self): got = ref.run(None, {"A": data, "B": indices}) self.assertEqualArray(y, got[0]) + def test_numpy_all(self): + data = np.array([[1, 0], [1, 1]]).astype(np.bool_) + y = np.all(data, axis=1) + + f = all_inline(Input("A"), axis=1) + self.assertIsInstance(f, Var) + onx = f.to_onnx(constraints={"A": Bool[None]}) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {"A": data}) + self.assertEqualArray(y, got[0]) + + def test_numpy_all_empty(self): + data = np.zeros((0,), dtype=np.bool_) + y = np.all(data) + + f = all_inline(Input("A")) + self.assertIsInstance(f, Var) + onx = f.to_onnx(constraints={"A": Bool[None]}) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {"A": data}) + self.assertEqualArray(y, got[0]) + + @unittest.skipIf(True, reason="ReduceMin does not support shape[axis] == 0") + def test_numpy_all_empty_axis_0(self): + data = np.zeros((0, 1), dtype=np.bool_) + y = np.all(data, axis=0) + + f = all_inline(Input("A"), axis=0) + self.assertIsInstance(f, Var) + onx = f.to_onnx(constraints={"A": Bool[None]}) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {"A": data}) + self.assertEqualArray(y, got[0]) + + def test_numpy_all_empty_axis_1(self): + data = np.zeros((0, 1), dtype=np.bool_) + y = np.all(data, axis=1) + + f = all_inline(Input("A"), axis=1) + self.assertIsInstance(f, Var) + onx = f.to_onnx(constraints={"A": Bool[None]}) + ref = ReferenceEvaluator(onx) + got = ref.run(None, {"A": data}) + self.assertEqualArray(y, got[0]) + if __name__ == "__main__": - TestNpx().test_take() + # TestNpx().test_numpy_all_empty_axis_0() unittest.main(verbosity=2) diff --git a/_unittests/ut_npx/test_sklearn_array_api.py b/_unittests/ut_npx/test_sklearn_array_api.py index 016a170..79120a9 100644 --- a/_unittests/ut_npx/test_sklearn_array_api.py +++ b/_unittests/ut_npx/test_sklearn_array_api.py @@ -4,7 +4,7 @@ from onnx.defs import onnx_opset_version from sklearn import config_context, __version__ as sklearn_version from sklearn.discriminant_analysis import LinearDiscriminantAnalysis -from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.ext_test_case import ExtTestCase, ignore_warnings from onnx_array_api.npx.npx_numpy_tensors import EagerNumpyTensor @@ -16,6 +16,7 @@ class TestSklearnArrayAPI(ExtTestCase): Version(sklearn_version) <= Version("1.2.2"), reason="reshape ArrayAPI not followed", ) + @ignore_warnings(DeprecationWarning) def test_sklearn_array_api_linear_discriminant(self): X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]]) y = np.array([1, 1, 1, 2, 2, 2]) @@ -26,6 +27,8 @@ def test_sklearn_array_api_linear_discriminant(self): new_x = EagerNumpyTensor(X) self.assertStartsWith("EagerNumpyTensor(array([[", repr(new_x)) with config_context(array_api_dispatch=True): + # It fails if scikit-learn <= 1.2.2 because the ArrayAPI + # is not strictly applied. got = ana.predict(new_x) self.assertEqualArray(expected, got.numpy()) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index aa1a59b..defe983 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -36,7 +36,7 @@ jobs: python -m pip install . -v -v -v displayName: 'install wheel' - script: | - python -m pytest -v + python -m pytest displayName: 'Runs Unit Tests' - task: PublishPipelineArtifact@0 inputs: @@ -87,9 +87,56 @@ jobs: black --diff . displayName: 'Black' - script: | - python -m pytest -v + python -m pytest displayName: 'Runs Unit Tests' +- job: 'TestLinuxArrayApi' + pool: + vmImage: 'ubuntu-latest' + strategy: + matrix: + Python310-Linux: + python.version: '3.11' + maxParallel: 3 + + steps: + - task: UsePythonVersion@0 + inputs: + versionSpec: '$(python.version)' + architecture: 'x64' + - script: sudo apt-get update + displayName: 'AptGet Update' + - script: python -m pip install --upgrade pip setuptools wheel + displayName: 'Install tools' + - script: pip install -r requirements.txt + displayName: 'Install Requirements' + - script: python setup.py install + displayName: 'Install onnx_array_api' + - script: | + git clone https://github.com/data-apis/array-api-tests.git + displayName: 'clone array-api-tests' + - script: | + cd array-api-tests + git submodule update --init --recursive + cd .. + displayName: 'get submodules for array-api-tests' + - script: pip install -r array-api-tests/requirements.txt + displayName: 'Install Requirements dev' + - script: | + export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy + cd array-api-tests + displayName: 'Set API' + - script: | + export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy + cd array-api-tests + python -m pytest -x array_api_tests/test_creation_functions.py::test_zeros + displayName: "test_creation_functions.py::test_zeros" + #- script: | + # export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy + # cd array-api-tests + # python -m pytest -x array_api_tests + # displayName: "all tests" + - job: 'TestLinux' pool: vmImage: 'ubuntu-latest' @@ -130,7 +177,7 @@ jobs: black --diff . displayName: 'Black' - script: | - python -m pytest -v + python -m pytest displayName: 'Runs Unit Tests' - script: | python -u setup.py bdist_wheel @@ -166,7 +213,7 @@ jobs: - script: pip install onnxmltools --no-deps displayName: 'Install onnxmltools' - script: | - python -m pytest -v + python -m pytest displayName: 'Runs Unit Tests' - script: | python -u setup.py bdist_wheel @@ -216,7 +263,7 @@ jobs: - script: pip install onnxmltools --no-deps displayName: 'Install onnxmltools' - script: | - python -m pytest -v -v + python -m pytest displayName: 'Runs Unit Tests' - script: | python -u setup.py bdist_wheel diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py new file mode 100644 index 0000000..e13b184 --- /dev/null +++ b/onnx_array_api/array_api/__init__.py @@ -0,0 +1,19 @@ +from onnx import TensorProto +from ..npx.npx_types import DType + + +def _finalize_array_api(module): + module.float16 = DType(TensorProto.FLOAT16) + module.float32 = DType(TensorProto.FLOAT) + module.float64 = DType(TensorProto.DOUBLE) + module.int8 = DType(TensorProto.INT8) + module.int16 = DType(TensorProto.INT16) + module.int32 = DType(TensorProto.INT32) + module.int64 = DType(TensorProto.INT64) + module.uint8 = DType(TensorProto.UINT8) + module.uint16 = DType(TensorProto.UINT16) + module.uint32 = DType(TensorProto.UINT32) + module.uint64 = DType(TensorProto.UINT64) + module.bfloat16 = DType(TensorProto.BFLOAT16) + setattr(module, "bool", DType(TensorProto.BOOL)) + setattr(module, "str", DType(TensorProto.STRING)) diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py new file mode 100644 index 0000000..8d136c4 --- /dev/null +++ b/onnx_array_api/array_api/_onnx_common.py @@ -0,0 +1,50 @@ +from typing import Any, Optional +import numpy as np +from ..npx.npx_types import DType +from ..npx.npx_array_api import BaseArrayApi +from ..npx.npx_functions import ( + copy as copy_inline, +) + + +def template_asarray( + TEagerTensor: type, + a: Any, + dtype: Optional[DType] = None, + order: Optional[str] = None, + like: Any = None, + copy: bool = False, +) -> Any: + """ + Converts anything into an array. + """ + if order not in ("C", None): + raise NotImplementedError(f"asarray is not implemented for order={order!r}.") + if like is not None: + raise NotImplementedError( + f"asarray is not implemented for like != None (type={type(like)})." + ) + if isinstance(a, BaseArrayApi): + if copy: + if dtype is None: + return copy_inline(a) + return copy_inline(a).astype(dtype=dtype) + if dtype is None: + return a + return a.astype(dtype=dtype) + + if isinstance(a, int): + v = TEagerTensor(np.array(a, dtype=np.int64)) + elif isinstance(a, float): + v = TEagerTensor(np.array(a, dtype=np.float32)) + elif isinstance(a, bool): + v = TEagerTensor(np.array(a, dtype=np.bool_)) + elif isinstance(a, str): + v = TEagerTensor(np.array(a, dtype=np.str_)) + else: + raise RuntimeError(f"Unexpected type {type(a)} for the first input.") + if dtype is not None: + vt = v.astype(dtype) + else: + vt = v + return vt diff --git a/onnx_array_api/array_api/onnx_numpy.py b/onnx_array_api/array_api/onnx_numpy.py new file mode 100644 index 0000000..79b339d --- /dev/null +++ b/onnx_array_api/array_api/onnx_numpy.py @@ -0,0 +1,70 @@ +""" +Array API valid for an :class:`EagerNumpyTensor`. +""" +from typing import Any, Optional +import numpy as np +from onnx import TensorProto +from ..npx.npx_functions import ( + all, + abs, + absolute, + astype, + equal, + isdtype, + reshape, + take, +) +from ..npx.npx_functions import zeros as generic_zeros +from ..npx.npx_numpy_tensors import EagerNumpyTensor +from ..npx.npx_types import DType, ElemType, TensorType, OptParType +from ._onnx_common import template_asarray +from . import _finalize_array_api + +__all__ = [ + "abs", + "absolute", + "all", + "asarray", + "astype", + "equal", + "isdtype", + "reshape", + "take", + "zeros", +] + + +def asarray( + a: Any, + dtype: Optional[DType] = None, + order: Optional[str] = None, + like: Any = None, + copy: bool = False, +) -> EagerNumpyTensor: + """ + Converts anything into an array. + """ + return template_asarray( + EagerNumpyTensor, a, dtype=dtype, order=order, like=like, copy=copy + ) + + +def zeros( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + if isinstance(shape, tuple): + return generic_zeros( + EagerNumpyTensor(np.array(shape, dtype=np.int64)), dtype=dtype, order=order + ) + return generic_zeros(shape, dtype=dtype, order=order) + + +def _finalize(): + from . import onnx_numpy + + _finalize_array_api(onnx_numpy) + + +_finalize() diff --git a/onnx_array_api/array_api/onnx_ort.py b/onnx_array_api/array_api/onnx_ort.py new file mode 100644 index 0000000..505efdf --- /dev/null +++ b/onnx_array_api/array_api/onnx_ort.py @@ -0,0 +1,54 @@ +""" +Array API valid for an :class:`EagerOrtTensor`. +""" +from typing import Optional, Any +from ..ort.ort_tensors import EagerOrtTensor +from ..npx.npx_types import DType +from ..npx.npx_functions import ( + all, + abs, + absolute, + astype, + equal, + isdtype, + reshape, + take, +) +from ._onnx_common import template_asarray +from . import _finalize_array_api + +__all__ = [ + "all", + "abs", + "absolute", + "asarray", + "astype", + "equal", + "isdtype", + "reshape", + "take", +] + + +def asarray( + a: Any, + dtype: Optional[DType] = None, + order: Optional[str] = None, + like: Any = None, + copy: bool = False, +) -> EagerOrtTensor: + """ + Converts anything into an array. + """ + return template_asarray( + EagerOrtTensor, a, dtype=dtype, order=order, like=like, copy=copy + ) + + +def _finalize(): + from . import onnx_ort + + _finalize_array_api(onnx_ort) + + +_finalize() diff --git a/onnx_array_api/npx/npx_array_api.py b/onnx_array_api/npx/npx_array_api.py index d5b2096..58968ae 100644 --- a/onnx_array_api/npx/npx_array_api.py +++ b/onnx_array_api/npx/npx_array_api.py @@ -20,15 +20,9 @@ class BaseArrayApi: def __array_namespace__(self, api_version: Optional[str] = None): """ - Returns the module holding all the available functions. + This method must be overloaded. """ - if api_version is None or api_version == "2022.12": - from onnx_array_api.npx import npx_functions - - return npx_functions - raise ValueError( - f"Unable to return an implementation for api_version={api_version!r}." - ) + raise NotImplementedError("Method '__array_namespace__' must be implemented.") def generic_method(self, method_name, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError( @@ -134,7 +128,7 @@ def T(self) -> "BaseArrayApi": return self.generic_method("T") def astype(self, dtype: Any) -> "BaseArrayApi": - return self.generic_method("astype", dtype) + return self.generic_method("astype", dtype=dtype) @property def shape(self) -> "BaseArrayApi": diff --git a/onnx_array_api/npx/npx_core_api.py b/onnx_array_api/npx/npx_core_api.py index cc3802a..05cb0bb 100644 --- a/onnx_array_api/npx/npx_core_api.py +++ b/onnx_array_api/npx/npx_core_api.py @@ -5,7 +5,7 @@ from onnx import FunctionProto, ModelProto, NodeProto from .npx_tensors import EagerTensor -from .npx_types import ElemType, OptParType, ParType, TupleType +from .npx_types import DType, ElemType, OptParType, ParType, TupleType from .npx_var import Cst, Input, ManyIdentity, Par, Var @@ -74,7 +74,7 @@ def _process_parameter(fn, sig, k, v, new_pars, inline): parent_op=(fn.__module__, fn.__name__, 0), ) return - if isinstance(v, (int, float, str, tuple)): + if isinstance(v, (int, float, str, tuple, DType)): if inline: new_pars[k] = v else: diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index f335bd0..b55cf4d 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -1,14 +1,13 @@ -from typing import Any, Optional, Tuple, Union +from typing import Optional, Tuple, Union import array_api_compat.numpy as np_array_api import numpy as np from onnx import FunctionProto, ModelProto, NodeProto, TensorProto -from onnx.helper import np_dtype_to_tensor_dtype +from onnx.helper import make_tensor, np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype from onnx.numpy_helper import from_array from .npx_constants import FUNCTION_DOMAIN from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var -from .npx_tensors import BaseArrayApi from .npx_types import ( DType, ElemType, @@ -43,6 +42,40 @@ def absolute( return var(x, op="Abs") +@npxapi_inline +def all( + x: TensorType[ElemType.bool_, "T"], + axis: Optional[TensorType[ElemType.int64, "I"]] = None, + keepdims: ParType[int] = 0, +) -> TensorType[ElemType.bool_, "T"]: + """ + See :func:`numpy.all`. + If input x is empty, the answer is True. + """ + # size = var(x, op="Size") + # empty = var(size, cst(np.array(0, dtype=np.int64)), op="Equal") + + # z = make_tensor_value_info("Z", TensorProto.BOOL, [1]) + # g1 = make_graph([make_node("Constant", [], ["Z"], value_bool=[True])], [], [z]) + + xi = var(x, op="Cast", to=TensorProto.INT64) + + if axis is None: + new_shape = cst(np.array([-1], dtype=np.int64)) + xifl = var(xi, new_shape, op="Reshape") + # in case xifl is empty, we need to add one element + one = cst(np.array([1], dtype=np.int64)) + xifl1 = var(xifl, one, op="Concat", axis=0) + red = xifl1.min(keepdims=keepdims) + else: + if isinstance(axis, int): + axis = [axis] + if isinstance(axis, (tuple, list)): + axis = cst(np.array(axis, dtype=np.int64)) + red = xi.min(axis, keepdims=keepdims) + return var(red, cst(1), op="Equal") + + @npxapi_inline def arccos(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: "See :func:`numpy.arccos`." @@ -159,30 +192,9 @@ def arctanh( return var(x, op="Atanh") -def asarray( - a: Any, - dtype: Any = None, - order: Optional[str] = None, - like: Any = None, - copy: bool = False, -): - """ - Converts anything into an array. - """ - if dtype is not None: - raise RuntimeError("Method 'astype' should be used to change the type.") - if order is not None: - raise NotImplementedError(f"order={order!r} not implemented.") - if isinstance(a, BaseArrayApi): - if copy: - return a.__class__(a, copy=copy) - return a - raise NotImplementedError(f"asarray not implemented for type {type(a)}.") - - @npxapi_inline def astype( - a: TensorType[ElemType.numerics, "T1"], dtype: OptParType[int] = 1 + a: TensorType[ElemType.numerics, "T1"], dtype: OptParType[DType] = 1 ) -> TensorType[ElemType.numerics, "T2"]: """ Cast an array. @@ -335,6 +347,14 @@ def einsum( return var(*x, op="Einsum", equation=equation) +@npxapi_inline +def equal( + x: TensorType[ElemType.allowed, "T"], y: TensorType[ElemType.allowed, "T"] +) -> TensorType[ElemType.bool_, "T1"]: + "See :func:`numpy.isnan`." + return var(x, y, op="Equal") + + @npxapi_inline def erf(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: "See :func:`scipy.special.erf`." @@ -382,18 +402,20 @@ def hstack( @npxapi_inline -def copy(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]: +def copy(x: TensorType[ElemType.allowed, "T"]) -> TensorType[ElemType.allowed, "T"]: "Makes a copy." return var(x, op="Identity") @npxapi_inline -def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]: +def identity( + n: ParType[int], dtype: OptParType[DType] = None +) -> TensorType[ElemType.numerics, "T"]: "Makes a copy." - val = np.array([n, n], dtype=np.int64) - shape = cst(val) model = var( - shape, op="ConstantOfShape", value=from_array(np.array([0], dtype=np.int64)) + cst(np.array([n, n], dtype=np.int64)), + op="ConstantOfShape", + value=from_array(np.array([0], dtype=np.int64)), ) v = var(model, dtype=dtype, op="EyeLike") return v @@ -401,17 +423,22 @@ def identity(n: ParType[int], dtype=None) -> TensorType[ElemType.numerics, "T"]: @npxapi_no_inline def isdtype( - dtype: DType, kind: Union[DType, str, Tuple[Union[DType, str], ...]] + dtype: ParType[DType], kind: Union[DType, str, Tuple[Union[DType, str], ...]] ) -> bool: """ See :epkg:`BaseArrayAPI:isdtype`. This function is not converted into an onnx graph. """ + if isinstance(dtype, DType): + dti = tensor_dtype_to_np_dtype(dtype.code) + return np_array_api.isdtype(dti, kind) + if isinstance(dtype, int): + raise TypeError(f"Unexpected type {type(dtype)}.") return np_array_api.isdtype(dtype, kind) @npxapi_inline -def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T"]: +def isnan(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.bool_, "T1"]: "See :func:`numpy.isnan`." return var(x, op="IsNaN") @@ -625,3 +652,23 @@ def where( ) -> TensorType[ElemType.numerics, "T"]: "See :func:`numpy.where`." return var(cond, x, y, op="Where") + + +@npxapi_inline +def zeros( + shape: TensorType[ElemType.int64, "I", (None,)], + dtype: OptParType[DType] = DType(TensorProto.FLOAT), + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + """ + Implements :func:`numpy.zeros`. + """ + if order != "C": + raise RuntimeError(f"order={order!r} != 'C' not supported.") + if dtype is None: + dtype = DType(TensorProto.FLOAT) + return var( + shape, + value=make_tensor(name="zero", data_type=dtype.code, dims=[1], vals=[0]), + op="ConstantOfShape", + ) diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index 92c1412..ec91b91 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -38,6 +38,7 @@ rename_in_onnx_graph, ) from .npx_types import ( + DType, ElemType, OptParType, ParType, @@ -226,6 +227,8 @@ def make_node( protos.append(att) elif v.value is not None: new_kwargs[k] = v.value + elif isinstance(v, DType): + new_kwargs[k] = v.code else: new_kwargs[k] = v @@ -337,7 +340,7 @@ def _io( if tensor_type.shape is None: type_proto = TypeProto() tensor_type_proto = type_proto.tensor_type - tensor_type_proto.elem_type = tensor_type.dtypes[0].dtype + tensor_type_proto.elem_type = tensor_type.dtypes[0].dtype.code value_info_proto = ValueInfoProto() value_info_proto.name = name # tensor_type_proto.shape.dim.extend([]) @@ -348,7 +351,7 @@ def _io( # with fixed rank. This can be changed here and in methods # `make_key`. shape = [None for _ in tensor_type.shape] - info = make_tensor_value_info(name, tensor_type.dtypes[0].dtype, shape) + info = make_tensor_value_info(name, tensor_type.dtypes[0].dtype.code, shape) # check_value_info fails if the shape is left undefined check_value_info(info, self.check_context) return info diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index 6b6bfca..85b52d4 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -5,7 +5,7 @@ import numpy as np from .npx_tensors import EagerTensor, JitTensor -from .npx_types import TensorType +from .npx_types import DType, TensorType from .npx_var import Cst, Input, Var logger = getLogger("onnx-array-api") @@ -131,7 +131,7 @@ def make_key(*values, **kwargs): for iv, v in enumerate(values): if isinstance(v, (Var, EagerTensor, JitTensor)): res.append(v.key) - elif isinstance(v, (int, float)): + elif isinstance(v, (int, float, DType)): res.append(v) elif isinstance(v, slice): res.append(("slice", v.start, v.stop, v.step)) @@ -153,7 +153,7 @@ def make_key(*values, **kwargs): ) if kwargs: for k, v in sorted(kwargs.items()): - if isinstance(v, (int, float, str, type)): + if isinstance(v, (int, float, str, type, DType)): res.append(k) res.append(v) elif isinstance(v, tuple): @@ -168,6 +168,8 @@ def make_key(*values, **kwargs): else: newv.append(t) res.append(tuple(newv)) + elif v is None and k in {"dtype"}: + continue else: raise TypeError( f"Type {type(v)} is not yet supported, " @@ -193,6 +195,12 @@ def to_jit(self, *values, **kwargs): constraints = {} new_kwargs = {} for i, (v, iname) in enumerate(zip(values, names)): + if i < len(annot_values) and not isinstance(annot_values[i], type): + raise TypeError( + f"annotation {i} is not a type but is {annot_values[i]!r}." + f"for function {self.f} " + f"from module {self.f.__module__!r}." + ) if isinstance(v, (EagerTensor, JitTensor)) and ( i >= len(annot_values) or issubclass(annot_values[i], TensorType) ): @@ -250,7 +258,7 @@ def to_jit(self, *values, **kwargs): kwargs = new_kwargs else: kwargs = kwargs.copy() - kwargs.update(kwargs) + kwargs.update(new_kwargs) var = self.f(*inputs, **kwargs) @@ -336,7 +344,13 @@ def jit_call(self, *values, **kwargs): self.info("+", "jit_call") if self.input_to_kwargs_ is None: # No jitting was ever called. - onx, fct = self.to_jit(*values, **kwargs) + try: + onx, fct = self.to_jit(*values, **kwargs) + except Exception as e: + raise RuntimeError( + f"ERROR with self.f={self.f}, " + f"values={values!r}, kwargs={kwargs!r}" + ) from e if self.input_to_kwargs_ is None: raise RuntimeError( f"Attribute 'input_to_kwargs_' should be set for " @@ -520,6 +534,8 @@ def _preprocess_constants(self, *args): elif isinstance(n, (int, float)): new_args.append(self.tensor_class(np.array(n))) modified = True + elif isinstance(n, DType): + new_args.append(n) elif n in (int, float): # usually used to cast new_args.append(n) @@ -554,7 +570,17 @@ def __call__(self, *args, already_eager=False, **kwargs): lambda t: t is not None and not isinstance( t, - (EagerTensor, Cst, int, float, tuple, slice, type, np.ndarray), + ( + EagerTensor, + Cst, + int, + float, + tuple, + slice, + type, + np.ndarray, + DType, + ), ), args, ) diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index 3197f60..e1a0c10 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -1,11 +1,13 @@ -from typing import Any, Callable, List, Tuple +from typing import Any, Callable, List, Optional, Tuple import numpy as np from onnx import ModelProto +from onnx.helper import np_dtype_to_tensor_dtype from onnx.reference import ReferenceEvaluator +from .npx_numpy_tensors_ops import ConstantOfShape from .npx_tensors import EagerTensor, JitTensor -from .npx_types import TensorType +from .npx_types import DType, TensorType class NumpyTensor: @@ -24,7 +26,7 @@ class Evaluator: """ def __init__(self, tensor_class: type, input_names: List[str], onx: ModelProto): - self.ref = ReferenceEvaluator(onx) + self.ref = ReferenceEvaluator(onx, new_ops=[ConstantOfShape]) self.input_names = input_names self.tensor_class = tensor_class @@ -54,17 +56,18 @@ def __init__(self, tensor: np.ndarray): elif isinstance( tensor, ( - np.int64, + np.float16, np.float32, np.float64, + np.int64, np.int32, - np.float16, - np.int8, np.int16, - np.uint8, - np.uint16, - np.uint32, + np.int8, np.uint64, + np.uint32, + np.uint16, + np.uint8, + np.bool_, ), ): self._tensor = np.array(tensor) @@ -80,9 +83,9 @@ def numpy(self): return self._tensor @property - def dtype(self) -> Any: + def dtype(self) -> DType: "Returns the element type of this tensor." - return self._tensor.dtype + return DType(np_dtype_to_tensor_dtype(self._tensor.dtype)) @property def key(self) -> Any: @@ -171,7 +174,17 @@ class EagerNumpyTensor(NumpyTensor, EagerTensor): Defines a value for a specific backend. """ - pass + def __array_namespace__(self, api_version: Optional[str] = None): + """ + Returns the module holding all the available functions. + """ + if api_version is None or api_version == "2022.12": + from onnx_array_api.array_api import onnx_numpy + + return onnx_numpy + raise ValueError( + f"Unable to return an implementation for api_version={api_version!r}." + ) class JitNumpyTensor(NumpyTensor, JitTensor): diff --git a/onnx_array_api/npx/npx_numpy_tensors_ops.py b/onnx_array_api/npx/npx_numpy_tensors_ops.py new file mode 100644 index 0000000..5278019 --- /dev/null +++ b/onnx_array_api/npx/npx_numpy_tensors_ops.py @@ -0,0 +1,46 @@ +import numpy as np + +from onnx.reference.op_run import OpRun + + +class ConstantOfShape(OpRun): + @staticmethod + def _process(value): + cst = value[0] if isinstance(value, np.ndarray) else value + if isinstance(cst, int): + cst = np.int64(cst) + elif isinstance(cst, float): + cst = np.float64(cst) + elif cst is None: + cst = np.float32(0) + if not isinstance( + cst, + ( + np.float16, + np.float32, + np.float64, + np.int64, + np.int32, + np.int16, + np.int8, + np.uint64, + np.uint32, + np.uint16, + np.uint8, + np.bool_, + ), + ): + raise TypeError(f"value must be a real not {type(cst)}") + return cst + + def _run(self, data, value=None): + cst = self._process(value) + try: + res = np.full(tuple(data), cst) + except TypeError as e: + raise RuntimeError( + f"Unable to create a constant of shape " + f"{data!r} with value {cst!r} " + f"(raw value={value!r})." + ) from e + return (res,) diff --git a/onnx_array_api/npx/npx_tensors.py b/onnx_array_api/npx/npx_tensors.py index 136def5..e1e4b21 100644 --- a/onnx_array_api/npx/npx_tensors.py +++ b/onnx_array_api/npx/npx_tensors.py @@ -1,8 +1,9 @@ -from typing import Any +from typing import Any, Union import numpy as np from onnx.helper import np_dtype_to_tensor_dtype +from .npx_types import DType, ElemType, ParType, TensorType from .npx_array_api import BaseArrayApi, ArrayApiError @@ -73,7 +74,9 @@ def _getitem_impl_var(obj, index, method_name=None): return meth(obj, index) @staticmethod - def _astype_impl(x, dtype: int = None, method_name=None): + def _astype_impl( + x: TensorType[ElemType.allowed, "T1"], dtype: ParType[DType], method_name=None + ) -> TensorType[ElemType.allowed, "T2"]: # avoids circular imports. if dtype is None: raise ValueError("dtype cannot be None.") @@ -131,9 +134,11 @@ def _generic_method_operator(self, method_name, *args: Any, **kwargs: Any) -> An new_args = [] for a in args: if isinstance(a, np.ndarray): - new_args.append(self.__class__(a.astype(self.dtype))) + new_args.append(self.__class__(a.astype(self.dtype.np_dtype))) elif isinstance(a, (int, float)): - new_args.append(self.__class__(np.array([a]).astype(self.dtype))) + new_args.append( + self.__class__(np.array([a]).astype(self.dtype.np_dtype)) + ) else: new_args.append(a) @@ -179,18 +184,17 @@ def _np_dtype_to_tensor_dtype(dtype): dtype = np.dtype("float64") return np_dtype_to_tensor_dtype(dtype) - def _generic_method_astype(self, method_name, *args: Any, **kwargs: Any) -> Any: + def _generic_method_astype( + self, method_name, dtype: Union[DType, "Var"], **kwargs: Any + ) -> Any: # avoids circular imports. from .npx_jit_eager import eager_onnx from .npx_var import Var - if len(args) != 1: - raise ValueError(f"astype takes only one argument not {len(args)}.") - dtype = ( - args[0] - if isinstance(args[0], (int, Var)) - else self._np_dtype_to_tensor_dtype(args[0]) + dtype + if isinstance(dtype, (DType, Var)) + else self._np_dtype_to_tensor_dtype(dtype) ) eag = eager_onnx(EagerTensor._astype_impl, self.__class__, bypass_eager=True) res = eag(self, dtype, method_name=method_name, already_eager=True, **kwargs) diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index a38d53f..aa335bd 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -1,7 +1,8 @@ from typing import Any, Tuple, Union import numpy as np -from onnx import AttributeProto +from onnx import AttributeProto, TensorProto +from onnx.helper import np_dtype_to_tensor_dtype, tensor_dtype_to_np_dtype class WrapperType: @@ -14,9 +15,80 @@ class WrapperType: class DType(WrapperType): """ - Annotated type for dtype. + Type of the element type returned by tensors + following the :epkg:`Array API`. + + :param code: element type based on onnx definition """ + __slots__ = ["code_"] + + def __init__(self, code: int): + self.code_ = code + + def __repr__(self) -> str: + "usual" + return f"DType({self.code_})" + + def __str__(self) -> str: + "usual" + return f"DT{self.code_}" + + def __hash__(self) -> int: + return self.code_ + + @property + def code(self) -> int: + return self.code_ + + @property + def np_dtype(self) -> "np.dtype": + return tensor_dtype_to_np_dtype(self.code_) + + def __eq__(self, dt: "DType") -> bool: + "Compares two types." + if dt.__class__ is DType: + return self.code_ == dt.code_ + if isinstance(dt, (int, bool, str)): + return False + if dt is str: + return self.code_ == TensorProto.STRING + if dt is bool: + return self.code_ == TensorProto.BOOL + if dt in ElemType.numpy_map: + dti = ElemType.numpy_map[dt] + return self.code_ == dti.code_ + try: + dti = np_dtype_to_tensor_dtype(dt) + except KeyError: + raise TypeError(f"dt must be DType not {type(dt)} - {dt!r}.") + return self.code_ == dti + + def __lt__(self, dt: "DType") -> bool: + "Compares two types." + if dt.__class__ is DType: + return self.code_ < dt.code_ + if isinstance(dt, int): + raise TypeError(f"dt must be DType not {type(dt)}.") + try: + dti = np_dtype_to_tensor_dtype(dt) + except KeyError: + raise TypeError(f"dt must be DType not {type(dt)} - {dt}.") + return self.code_ < dti + + @classmethod + def type_name(cls) -> str: + "Returns its full name." + raise NotImplementedError() + + +class _DType2(DType): + "Wraps an into a different type." + pass + + +class _DTypes(DType): + "Wraps an into a different type." pass @@ -27,22 +99,23 @@ class ElemTypeCstInner(WrapperType): __slots__ = [] - undefined = 0 - bool_ = 9 - int8 = 3 - int16 = 5 - int32 = 6 - int64 = 7 - uint8 = 2 - uint16 = 4 - uint32 = 12 - uint64 = 13 - float16 = 10 - float32 = 1 - float64 = 11 - bfloat16 = 16 - complex64 = 14 - complex128 = 15 + undefined = DType(0) + bool_ = DType(9) + int8 = DType(3) + int16 = DType(5) + int32 = DType(6) + int64 = DType(7) + uint8 = DType(2) + uint16 = DType(4) + uint32 = DType(12) + uint64 = DType(13) + float16 = DType(10) + float32 = DType(1) + float64 = DType(11) + bfloat16 = DType(16) + complex64 = DType(14) + complex128 = DType(15) + str_ = DType(8) class ElemTypeCstSet(ElemTypeCstInner): @@ -50,7 +123,7 @@ class ElemTypeCstSet(ElemTypeCstInner): Sets of element types. """ - allowed = set(range(1, 17)) + allowed = set(DType(i) for i in range(1, 17)) ints = { ElemTypeCstInner.int8, @@ -85,13 +158,15 @@ class ElemTypeCstSet(ElemTypeCstInner): ElemTypeCstInner.float64, } + strings = {ElemTypeCstInner.str_} + @staticmethod def combined(type_set): "Combines all types into a single integer by using power of 2." s = 0 for dt in type_set: - s += 1 << dt - return s + s += 1 << dt.code + return _DTypes(s) class ElemTypeCst(ElemTypeCstSet): @@ -99,45 +174,47 @@ class ElemTypeCst(ElemTypeCstSet): Combination of element types. """ - Undefined = 0 - Bool = 1 << ElemTypeCstInner.bool_ - Int8 = 1 << ElemTypeCstInner.int8 - Int16 = 1 << ElemTypeCstInner.int16 - Int32 = 1 << ElemTypeCstInner.int32 - Int64 = 1 << ElemTypeCstInner.int64 - UInt8 = 1 << ElemTypeCstInner.uint8 - UInt16 = 1 << ElemTypeCstInner.uint16 - UInt32 = 1 << ElemTypeCstInner.uint32 - UInt64 = 1 << ElemTypeCstInner.uint64 - BFloat16 = 1 << ElemTypeCstInner.bfloat16 - Float16 = 1 << ElemTypeCstInner.float16 - Float32 = 1 << ElemTypeCstInner.float32 - Float64 = 1 << ElemTypeCstInner.float64 - Complex64 = 1 << ElemTypeCstInner.complex64 - Complex128 = 1 << ElemTypeCstInner.complex128 + Undefined = _DType2(0) + Bool = _DType2(1 << ElemTypeCstInner.bool_.code) + Int8 = _DType2(1 << ElemTypeCstInner.int8.code) + Int16 = _DType2(1 << ElemTypeCstInner.int16.code) + Int32 = _DType2(1 << ElemTypeCstInner.int32.code) + Int64 = _DType2(1 << ElemTypeCstInner.int64.code) + UInt8 = _DType2(1 << ElemTypeCstInner.uint8.code) + UInt16 = _DType2(1 << ElemTypeCstInner.uint16.code) + UInt32 = _DType2(1 << ElemTypeCstInner.uint32.code) + UInt64 = _DType2(1 << ElemTypeCstInner.uint64.code) + BFloat16 = _DType2(1 << ElemTypeCstInner.bfloat16.code) + Float16 = _DType2(1 << ElemTypeCstInner.float16.code) + Float32 = _DType2(1 << ElemTypeCstInner.float32.code) + Float64 = _DType2(1 << ElemTypeCstInner.float64.code) + Complex64 = _DType2(1 << ElemTypeCstInner.complex64.code) + Complex128 = _DType2(1 << ElemTypeCstInner.complex128.code) + String = _DType2(1 << ElemTypeCstInner.str_.code) Numerics = ElemTypeCstSet.combined(ElemTypeCstSet.numerics) Floats = ElemTypeCstSet.combined(ElemTypeCstSet.floats) Ints = ElemTypeCstSet.combined(ElemTypeCstSet.ints) + Strings = ElemTypeCstSet.combined(ElemTypeCstSet.strings) class ElemType(ElemTypeCst): """ Allowed element type based on numpy dtypes. - :param dtype: integer or a string + :param dtype: DType or a string """ names_int = { att: getattr(ElemTypeCstInner, att) for att in dir(ElemTypeCstInner) - if isinstance(getattr(ElemTypeCstInner, att), int) + if isinstance(getattr(ElemTypeCstInner, att), DType) } int_names = { getattr(ElemTypeCstInner, att): att for att in dir(ElemTypeCstInner) - if isinstance(getattr(ElemTypeCstInner, att), int) + if isinstance(getattr(ElemTypeCstInner, att), DType) } set_names = { @@ -150,24 +227,24 @@ class ElemType(ElemTypeCst): **{ getattr(np, att): getattr(ElemTypeCst, att) for att in dir(ElemTypeCst) - if isinstance(getattr(ElemTypeCst, att), int) and hasattr(np, att) + if isinstance(getattr(ElemTypeCst, att), DType) and hasattr(np, att) }, **{ np.dtype(att): getattr(ElemTypeCst, att) for att in dir(ElemTypeCst) - if isinstance(getattr(ElemTypeCst, att), int) and hasattr(np, att) + if isinstance(getattr(ElemTypeCst, att), DType) and hasattr(np, att) }, } __slots__ = ["dtype"] @classmethod - def __class_getitem__(cls, dtype: Union[str, int]): + def __class_getitem__(cls, dtype: Union[str, DType]): if isinstance(dtype, str): dtype = ElemType.names_int[dtype] elif dtype in ElemType.numpy_map: dtype = ElemType.numpy_map[dtype] - elif dtype == 0: + elif dtype == DType(0): pass elif dtype not in ElemType.allowed: raise ValueError(f"Unexpected dtype {dtype} not in {ElemType.allowed}.") @@ -197,7 +274,10 @@ def get_set_name(cls, dtypes): tt.append(dt.dtype) dtypes = set(tt) for d in dir(cls): - if dtypes == getattr(cls, d): + att = getattr(cls, d) + if not isinstance(att, set): + continue + if dtypes == att: return d return None @@ -210,7 +290,7 @@ class ParType(WrapperType): :param optional: is optional or not """ - map_names = {int: "int", float: "float", str: "str"} + map_names = {int: "int", float: "float", str: "str", DType: "DType"} @classmethod def __class_getitem__(cls, dtype): @@ -333,7 +413,7 @@ def __class_getitem__(cls, *args): if isinstance(a, tuple): shape = a continue - if isinstance(a, int): + if isinstance(a, DType): if dtypes is not None: raise TypeError(f"Unexpected type {type(a)} in {args}.") dtypes = (a,) @@ -363,7 +443,7 @@ def __class_getitem__(cls, *args): check.append(dt) elif dt in ElemType.allowed: check.append(ElemType[dt]) - elif isinstance(dt, int): + elif isinstance(dt, DType): check.append(ElemType[dt]) else: raise TypeError(f"Unexpected type {type(dt)} in {dtypes}, args={args}.") diff --git a/onnx_array_api/npx/npx_var.py b/onnx_array_api/npx/npx_var.py index c67e0ff..ae5b732 100644 --- a/onnx_array_api/npx/npx_var.py +++ b/onnx_array_api/npx/npx_var.py @@ -6,7 +6,7 @@ from .npx_array_api import BaseArrayApi, ArrayApiError from .npx_constants import DEFAULT_OPSETS, ONNX_DOMAIN -from .npx_types import ElemType, OptParType, ParType, TensorType, TupleType +from .npx_types import DType, ElemType, OptParType, ParType, TensorType, TupleType class Par: @@ -276,7 +276,7 @@ def __init__( op: Union[ Callable, str, Tuple[str, str], FunctionProto, ModelProto, NodeProto ] = None, - dtype: type = None, + dtype: Union[type, DType] = None, inline: bool = False, n_var_outputs: Optional[int] = 1, input_indices: Optional[List[int]] = None, @@ -298,11 +298,11 @@ def __init__( self.onnx_op_kwargs = kwargs self._prefix = None - if hasattr(dtype, "type_name"): - self.dtype = dtype - elif isinstance(dtype, int): + if isinstance(dtype, DType): # regular parameter self.onnx_op_kwargs["dtype"] = dtype + elif hasattr(dtype, "type_name"): + self.dtype = dtype elif dtype is None: self.dtype = None else: diff --git a/onnx_array_api/ort/ort_tensors.py b/onnx_array_api/ort/ort_tensors.py index 63bc378..ead834d 100644 --- a/onnx_array_api/ort/ort_tensors.py +++ b/onnx_array_api/ort/ort_tensors.py @@ -3,7 +3,6 @@ import numpy as np from onnx import ModelProto, TensorProto from onnx.defs import onnx_opset_version -from onnx.helper import tensor_dtype_to_np_dtype from onnxruntime import InferenceSession, RunOptions, get_available_providers from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice from onnxruntime.capi._pybind_state import OrtMemType @@ -11,7 +10,7 @@ from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument from ..npx.npx_tensors import EagerTensor, JitTensor -from ..npx.npx_types import TensorType +from ..npx.npx_types import DType, TensorType class OrtTensor: @@ -152,9 +151,9 @@ def shape(self) -> Tuple[int, ...]: return self._tensor.shape() @property - def dtype(self) -> Any: + def dtype(self) -> DType: "Returns the element type of this tensor." - return tensor_dtype_to_np_dtype(self._tensor.element_type()) + return DType(self._tensor.element_type()) @property def key(self) -> Any: @@ -234,7 +233,17 @@ class EagerOrtTensor(OrtTensor, OrtCommon, EagerTensor): Defines a value for :epkg:`onnxruntime` as a backend. """ - pass + def __array_namespace__(self, api_version: Optional[str] = None): + """ + Returns the module holding all the available functions. + """ + if api_version is None or api_version == "2022.12": + from onnx_array_api.array_api import onnx_ort + + return onnx_ort + raise ValueError( + f"Unable to return an implementation for api_version={api_version!r}." + ) class JitOrtTensor(OrtTensor, OrtCommon, JitTensor): diff --git a/onnx_array_api/plotting/_helper.py b/onnx_array_api/plotting/_helper.py index 69ea987..48e65d9 100644 --- a/onnx_array_api/plotting/_helper.py +++ b/onnx_array_api/plotting/_helper.py @@ -11,6 +11,7 @@ ) from onnx.helper import tensor_dtype_to_np_dtype from onnx.numpy_helper import to_array +from ..npx.npx_types import DType class Graph: @@ -44,7 +45,7 @@ def __init__( self.shape = shape @property - def dtype(self) -> Any: + def dtype(self) -> DType: return self.values.dtype diff --git a/pyproject.toml b/pyproject.toml index 832a027..9ef84cc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,6 +3,7 @@ report_level = "INFO" ignore_directives = [ "autoclass", "autofunction", + "automodule", "gdot", "image-sg", "runpython", @@ -29,10 +30,13 @@ max-complexity = 10 [tool.ruff.per-file-ignores] "_doc/examples/plot_first_example.py" = ["E402", "F811"] "_doc/examples/plot_onnxruntime.py" = ["E402", "F811"] -"onnx_array_api/profiling.py" = ["E731"] +"onnx_array_api/array_api/onnx_numpy.py" = ["F821"] +"onnx_array_api/array_api/onnx_ort.py" = ["F821"] "onnx_array_api/npx/__init__.py" = ["F401", "F403"] "onnx_array_api/npx/npx_functions.py" = ["F821"] "onnx_array_api/npx/npx_functions_test.py" = ["F821"] +"onnx_array_api/npx/npx_tensors.py" = ["F821"] "onnx_array_api/npx/npx_var.py" = ["F821"] +"onnx_array_api/profiling.py" = ["E731"] "_unittests/ut_npx/test_npx.py" = ["F821"] pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy