diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index d0b47ab..0d3ae03 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -1,10 +1,8 @@ # API failures # see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt -array_api_tests/test_creation_functions.py::test_asarray_scalars -array_api_tests/test_creation_functions.py::test_arange +# uses __setitem__ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like -array_api_tests/test_creation_functions.py::test_eye array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index 0a003c1..43301de 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like || exit 1 +pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_eye || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 diff --git a/_unittests/ut_array_api/test_hypothesis_array_api.py b/_unittests/ut_array_api/test_hypothesis_array_api.py index fdf48f9..47bb38f 100644 --- a/_unittests/ut_array_api/test_hypothesis_array_api.py +++ b/_unittests/ut_array_api/test_hypothesis_array_api.py @@ -39,6 +39,7 @@ def sh(x): class TestHypothesisArraysApis(ExtTestCase): MAX_ARRAY_SIZE = 10000 + SQRT_MAX_ARRAY_SIZE = int(10000**0.5) VERSION = "2021.12" @classmethod @@ -138,9 +139,80 @@ def fctonx(x, kw): fctonx() self.assertEqual(len(args_onxp), len(args_np)) + def test_square_sizes_strategies(self): + dtypes = dict( + integer_dtypes=self.xps.integer_dtypes(), + uinteger_dtypes=self.xps.unsigned_integer_dtypes(), + floating_dtypes=self.xps.floating_dtypes(), + numeric_dtypes=self.xps.numeric_dtypes(), + boolean_dtypes=self.xps.boolean_dtypes(), + scalar_dtypes=self.xps.scalar_dtypes(), + ) + + dtypes_onnx = dict( + integer_dtypes=self.onxps.integer_dtypes(), + uinteger_dtypes=self.onxps.unsigned_integer_dtypes(), + floating_dtypes=self.onxps.floating_dtypes(), + numeric_dtypes=self.onxps.numeric_dtypes(), + boolean_dtypes=self.onxps.boolean_dtypes(), + scalar_dtypes=self.onxps.scalar_dtypes(), + ) + + for k, vnp in dtypes.items(): + vonxp = dtypes_onnx[k] + anp = self.xps.arrays(dtype=vnp, shape=shapes(self.xps)) + aonxp = self.onxps.arrays(dtype=vonxp, shape=shapes(self.onxps)) + self.assertNotEmpty(anp) + self.assertNotEmpty(aonxp) + + args_np = [] + + kws = array_api_kwargs(k=strategies.integers(), dtype=self.xps.numeric_dtypes()) + sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE) + ncs = strategies.none() | sqrt_sizes + + @given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws) + def fctnp(n_rows, n_cols, kw): + base = np.asarray(0) + e = np.eye(n_rows, n_cols) + self.assertNotEmpty(e.dtype) + self.assertIsInstance(e, base.__class__) + e = np.eye(n_rows, n_cols, **kw) + self.assertNotEmpty(e.dtype) + self.assertIsInstance(e, base.__class__) + args_np.append((n_rows, n_cols, kw)) + + fctnp() + self.assertEqual(len(args_np), 100) + + args_onxp = [] + + kws = array_api_kwargs( + k=strategies.integers(), dtype=self.onxps.numeric_dtypes() + ) + sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE) + ncs = strategies.none() | sqrt_sizes + + @given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws) + def fctonx(n_rows, n_cols, kw): + base = onxp.asarray(0) + e = onxp.eye(n_rows, n_cols) + self.assertIsInstance(e, base.__class__) + self.assertNotEmpty(e.dtype) + e = onxp.eye(n_rows, n_cols, **kw) + self.assertNotEmpty(e.dtype) + self.assertIsInstance(e, base.__class__) + args_onxp.append((n_rows, n_cols, kw)) + + fctonx() + self.assertEqual(len(args_onxp), len(args_np)) + if __name__ == "__main__": # cl = TestHypothesisArraysApis() # cl.setUpClass() # cl.test_scalar_strategies() + # import logging + + # logging.basicConfig(level=logging.DEBUG) unittest.main(verbosity=2) diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index e96e324..8fa746b 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -142,10 +142,28 @@ def test_as_array(self): self.assertEqual(r.dtype, DType(TensorProto.UINT64)) self.assertEqual(r.numpy(), 9223372036854775809) + def test_eye(self): + nr, nc = xp.asarray(4), xp.asarray(4) + expected = np.eye(nr.numpy(), nc.numpy()) + got = xp.eye(nr, nc) + self.assertEqualArray(expected, got.numpy()) + + def test_eye_nosquare(self): + nr, nc = xp.asarray(4), xp.asarray(5) + expected = np.eye(nr.numpy(), nc.numpy()) + got = xp.eye(nr, nc) + self.assertEqualArray(expected, got.numpy()) + + def test_eye_k(self): + nr = xp.asarray(4) + expected = np.eye(nr.numpy(), k=1) + got = xp.eye(nr, k=1) + self.assertEqualArray(expected, got.numpy()) + if __name__ == "__main__": # import logging # logging.basicConfig(level=logging.DEBUG) - # TestOnnxNumpy().test_as_array() + TestOnnxNumpy().test_eye() unittest.main(verbosity=2) diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index 6e4d712..1a305ca 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -17,6 +17,7 @@ "astype", "empty", "equal", + "eye", "full", "full_like", "isdtype", diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 7c2e59e..6f31d30 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -1,6 +1,7 @@ from typing import Any, Optional import warnings import numpy as np +from onnx import TensorProto with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -19,6 +20,8 @@ from ..npx.npx_functions import ( abs as generic_abs, arange as generic_arange, + copy as copy_inline, + eye as generic_eye, full as generic_full, full_like as generic_full_like, ones as generic_ones, @@ -185,6 +188,24 @@ def full( return generic_full(shape, fill_value=value, dtype=dtype, order=order) +def eye( + TEagerTensor: type, + n_rows: TensorType[ElemType.int64, "I"], + n_cols: OptTensorType[ElemType.int64, "I"] = None, + /, + *, + k: ParType[int] = 0, + dtype: ParType[DType] = DType(TensorProto.DOUBLE), +): + if isinstance(n_rows, int): + n_rows = TEagerTensor(np.array(n_rows, dtype=np.int64)) + if n_cols is None: + n_cols = n_rows + elif isinstance(n_cols, int): + n_cols = TEagerTensor(np.array(n_cols, dtype=np.int64)) + return generic_eye(n_rows, n_cols, k=k, dtype=dtype) + + def full_like( TEagerTensor: type, x: TensorType[ElemType.allowed, "T"], diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 5c202f8..beb22b6 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -473,6 +473,30 @@ def expit( return var(x, op="Sigmoid") +@npxapi_inline +def eye( + n_rows: TensorType[ElemType.int64, "I"], + n_cols: TensorType[ElemType.int64, "I"], + /, + *, + k: ParType[int] = 0, + dtype: ParType[DType] = DType(TensorProto.DOUBLE), +): + "See :func:`numpy.eye`." + shape = cst(np.array([-1], dtype=np.int64)) + shape = var( + var(n_rows, shape, op="Reshape"), + var(n_cols, shape, op="Reshape"), + axis=0, + op="Concat", + ) + zero = zeros(shape, dtype=dtype) + res = var(zero, k=k, op="EyeLike") + if dtype is not None: + return var(res, to=dtype.code, op="Cast") + return res + + @npxapi_inline def full( shape: TensorType[ElemType.int64, "I", (None,)], diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py index e8e49a2..b5333b5 100644 --- a/onnx_array_api/npx/npx_graph_builder.py +++ b/onnx_array_api/npx/npx_graph_builder.py @@ -230,6 +230,11 @@ def make_node( new_kwargs[k] = v.value elif isinstance(v, DType): new_kwargs[k] = v.code + elif isinstance(v, int): + try: + new_kwargs[k] = int(np.array(v, dtype=np.int64)) + except OverflowError: + new_kwargs[k] = int(np.iinfo(np.int64).max) else: new_kwargs[k] = v @@ -246,6 +251,11 @@ def make_node( f"Unable to create node {op!r}, with inputs={inputs}, " f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}." ) from e + except ValueError as e: + raise ValueError( + f"Unable to create node {op!r}, with inputs={inputs}, " + f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}." + ) from e for p in protos: node.attribute.append(p) if attribute_protos is not None: diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index e06c944..71799f9 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -510,11 +510,18 @@ def jit_call(self, *values, **kwargs): from ..plotting.text_plot import onnx_simple_text_plot text = onnx_simple_text_plot(self.onxs[key]) + + def catch_len(x): + try: + return len(x) + except TypeError: + return 0 + raise RuntimeError( f"Unable to run function for key={key!r}, " f"types={[type(x) for x in values]}, " f"dtypes={[getattr(x, 'dtype', type(x)) for x in values]}, " - f"shapes={[getattr(x, 'shape', len(x)) for x in values]}, " + f"shapes={[getattr(x, 'shape', catch_len(x)) for x in values]}, " f"kwargs={kwargs}, " f"self.input_to_kwargs_={self.input_to_kwargs_}, " f"f={self.f} from module {self.f.__module__!r} " diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py index aa26127..77a9344 100644 --- a/onnx_array_api/reference/evaluator.py +++ b/onnx_array_api/reference/evaluator.py @@ -7,10 +7,6 @@ from .ops.op_cast_like import CastLike_15, CastLike_19 from .ops.op_constant_of_shape import ConstantOfShape -import onnx - -print(onnx.__file__) - logger = getLogger("onnx-array-api-eval")
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: