From 523d68d4834afb2e4335dedb761de24a156c43b3 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 30 Jun 2023 09:50:24 +0200 Subject: [PATCH 1/3] Add full_like for the array API --- _unittests/onnx-numpy-skips.txt | 4 +- _unittests/test_array_api.sh | 2 +- _unittests/ut_array_api/test_onnx_numpy.py | 9 ++++- onnx_array_api/array_api/__init__.py | 1 + onnx_array_api/array_api/_onnx_common.py | 16 ++++++++ onnx_array_api/npx/npx_functions.py | 47 +++++++++++++++++++++- onnx_array_api/npx/npx_types.py | 2 +- 7 files changed, 74 insertions(+), 7 deletions(-) diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index a3eaa47..b035005 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -6,7 +6,7 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like array_api_tests/test_creation_functions.py::test_eye -array_api_tests/test_creation_functions.py::test_full_like +# array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid -array_api_tests/test_creation_functions.py::test_zeros_like +# array_api_tests/test_creation_functions.py::test_zeros_like diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index abab39b..5df14e2 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_ones_like || exit 1 +pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_full_like || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 859c802..1d197ce 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -112,7 +112,14 @@ def test_ones_like_uint16(self): expected = np.array(1, dtype=np.uint16) self.assertEqualArray(expected, z.numpy()) + def test_full_like(self): + c = EagerTensor(np.array(False)) + mat = xp.full(c, fill_value=False) + matnp = mat.numpy() + self.assertEqual(matnp.shape, tuple()) + self.assertEqulaArray(mat, matnp.numpy()) + if __name__ == "__main__": - # TestOnnxNumpy().test_ones_like() + TestOnnxNumpy().test_full_like() unittest.main(verbosity=2) diff --git a/onnx_array_api/array_api/__init__.py b/onnx_array_api/array_api/__init__.py index e1e09b8..bd762be 100644 --- a/onnx_array_api/array_api/__init__.py +++ b/onnx_array_api/array_api/__init__.py @@ -18,6 +18,7 @@ "empty", "equal", "full", + "full_like", "isdtype", "isfinite", "isinf", diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 2a67f22..b23b71f 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -20,6 +20,7 @@ abs as generic_abs, arange as generic_arange, full as generic_full, + full_like as generic_full_like, ones as generic_ones, zeros as generic_zeros, ) @@ -177,6 +178,21 @@ def full( return generic_full(shape, fill_value=value, dtype=dtype, order=order) +def full_like( + TEagerTensor: type, + x: TensorType[ElemType.allowed, "T"], + /, + fill_value: ParType[Scalar] = None, + *, + dtype: OptParType[DType] = None, + order: OptParType[str] = "C", +) -> EagerTensor[TensorType[ElemType.allowed, "TR"]]: + if dtype is None: + if isinstance(fill_value, TEagerTensor): + dtype = fill_value.dtype + return generic_full_like(x, fill_value, dtype=dtype, order=order) + + def ones( TEagerTensor: type, shape: EagerTensor[TensorType[ElemType.int64, "I", (None,)]], diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 94de749..46e18fa 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -275,9 +275,9 @@ def astype( if dtype is int: to = DType(TensorProto.INT64) elif dtype is float: - to = DType(TensorProto.FLOAT64) + to = DType(TensorProto.DOUBLE) elif dtype is bool: - to = DType(TensorProto.FLOAT64) + to = DType(TensorProto.BOOL) elif dtype is str: to = DType(TensorProto.STRING) else: @@ -511,6 +511,49 @@ def full( return var(shape, value=value, op="ConstantOfShape") +@npxapi_inline +def full_like( + x: TensorType[ElemType.allowed, "T"], + /, + fill_value: ParType[Scalar] = None, + *, + dtype: OptParType[DType] = None, + order: OptParType[str] = "C", +) -> TensorType[ElemType.numerics, "T"]: + """ + Implements :func:`numpy.zeros`. + """ + if order != "C": + raise RuntimeError(f"order={order!r} != 'C' not supported.") + if fill_value is None: + raise TypeError("fill_value cannot be None.") + if dtype is None: + if isinstance(fill_value, bool): + dtype = DType(TensorProto.BOOL) + elif isinstance(fill_value, int): + dtype = DType(TensorProto.INT64) + elif isinstance(fill_value, float): + dtype = DType(TensorProto.DOUBLE) + else: + raise TypeError( + f"Unexpected type {type(fill_value)} for fill_value={fill_value!r} " + f"and dtype={dtype!r}." + ) + if isinstance(fill_value, (float, int, bool)): + value = make_tensor( + name="cst", data_type=dtype.code, dims=[1], vals=[fill_value] + ) + else: + raise NotImplementedError( + f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}." + ) + + v = var(x.shape, value=value, op="ConstantOfShape") + if dtype is None: + return var(v, x, op="CastLike") + return v + + @npxapi_inline def floor( x: TensorType[ElemType.numerics, "T"], / diff --git a/onnx_array_api/npx/npx_types.py b/onnx_array_api/npx/npx_types.py index fe7b287..54cc618 100644 --- a/onnx_array_api/npx/npx_types.py +++ b/onnx_array_api/npx/npx_types.py @@ -68,7 +68,7 @@ def __eq__(self, dt: "DType") -> bool: if dt is bool: return self.code_ == TensorProto.BOOL if dt is float: - return self.code_ == TensorProto.FLOAT64 + return self.code_ == TensorProto.DOUBLE if isinstance(dt, list): return False if dt in ElemType.numpy_map: From 9c1b6da517b316259b0b133dc2446c1056b65b93 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 30 Jun 2023 11:08:59 +0200 Subject: [PATCH 2/3] improvment --- .../ut_array_api/test_hypothesis_array_api.py | 6 +++--- _unittests/ut_array_api/test_onnx_numpy.py | 17 ++++++++++++++--- onnx_array_api/array_api/_onnx_common.py | 2 +- onnx_array_api/npx/npx_functions.py | 2 +- onnx_array_api/npx/npx_jit_eager.py | 15 ++++++++++----- onnx_array_api/npx/npx_numpy_tensors.py | 3 +-- onnx_array_api/reference/evaluator.py | 17 +++++++++++++++++ .../ops/op_constant_of_shape.py} | 10 ++++++++-- 8 files changed, 55 insertions(+), 17 deletions(-) rename onnx_array_api/{npx/npx_numpy_tensors_ops.py => reference/ops/op_constant_of_shape.py} (78%) diff --git a/_unittests/ut_array_api/test_hypothesis_array_api.py b/_unittests/ut_array_api/test_hypothesis_array_api.py index e29af65..fdf48f9 100644 --- a/_unittests/ut_array_api/test_hypothesis_array_api.py +++ b/_unittests/ut_array_api/test_hypothesis_array_api.py @@ -140,7 +140,7 @@ def fctonx(x, kw): if __name__ == "__main__": - cl = TestHypothesisArraysApis() - cl.setUpClass() - cl.test_scalar_strategies() + # cl = TestHypothesisArraysApis() + # cl.setUpClass() + # cl.test_scalar_strategies() unittest.main(verbosity=2) diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 1d197ce..577f64e 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -114,12 +114,23 @@ def test_ones_like_uint16(self): def test_full_like(self): c = EagerTensor(np.array(False)) - mat = xp.full(c, fill_value=False) + expected = np.full_like(c.numpy(), fill_value=False) + mat = xp.full_like(c, fill_value=False) matnp = mat.numpy() self.assertEqual(matnp.shape, tuple()) - self.assertEqulaArray(mat, matnp.numpy()) + self.assertEqualArray(expected, matnp) + + def test_full_like_mx(self): + c = EagerTensor(np.array([], dtype=np.uint8)) + expected = np.full_like(c.numpy(), fill_value=0) + mat = xp.full_like(c, fill_value=0) + matnp = mat.numpy() + self.assertEqualArray(expected, matnp) if __name__ == "__main__": - TestOnnxNumpy().test_full_like() + import logging + + logging.basicConfig(level=logging.DEBUG) + TestOnnxNumpy().test_full_like_mx() unittest.main(verbosity=2) diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index b23b71f..52ae566 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -190,7 +190,7 @@ def full_like( if dtype is None: if isinstance(fill_value, TEagerTensor): dtype = fill_value.dtype - return generic_full_like(x, fill_value, dtype=dtype, order=order) + return generic_full_like(x, fill_value=fill_value, dtype=dtype, order=order) def ones( diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py index 46e18fa..33ad74b 100644 --- a/onnx_array_api/npx/npx_functions.py +++ b/onnx_array_api/npx/npx_functions.py @@ -515,8 +515,8 @@ def full( def full_like( x: TensorType[ElemType.allowed, "T"], /, - fill_value: ParType[Scalar] = None, *, + fill_value: ParType[Scalar] = None, dtype: OptParType[DType] = None, order: OptParType[str] = "C", ) -> TensorType[ElemType.numerics, "T"]: diff --git a/onnx_array_api/npx/npx_jit_eager.py b/onnx_array_api/npx/npx_jit_eager.py index b49d7ce..e06c944 100644 --- a/onnx_array_api/npx/npx_jit_eager.py +++ b/onnx_array_api/npx/npx_jit_eager.py @@ -58,6 +58,7 @@ def info( kwargs: Optional[Dict[str, Any]] = None, key: Optional[Tuple[Any, ...]] = None, onx: Optional[ModelProto] = None, + output: Optional[Any] = None, ): """ Logs a status. @@ -93,6 +94,8 @@ def info( "" if args is None else str(args), "" if kwargs is None else str(kwargs), ) + if output is not None: + logger.debug("==== [%s]", output) def status(self, me: str) -> str: """ @@ -517,7 +520,7 @@ def jit_call(self, *values, **kwargs): f"f={self.f} from module {self.f.__module__!r} " f"onnx=\n---\n{text}\n---\n{self.onxs[key]}" ) from e - self.info("-", "jit_call") + self.info("-", "jit_call", output=res) return res @@ -737,11 +740,13 @@ def __call__(self, *args, already_eager=False, **kwargs): try: res = self.f(*values, **kwargs) except (AttributeError, TypeError) as e: - inp1 = ", ".join(map(str, map(type, args))) - inp2 = ", ".join(map(str, map(type, values))) + inp1 = ", ".join(map(str, map(lambda a: type(a).__name__, args))) + inp2 = ", ".join(map(str, map(lambda a: type(a).__name__, values))) raise TypeError( - f"Unexpected types, input types are {inp1} " - f"and {inp2}, kwargs={kwargs}." + f"Unexpected types, input types are args=[{inp1}], " + f"values=[{inp2}], kwargs={kwargs}. " + f"(values = self._preprocess_constants(args)) " + f"args={args}, values={values}" ) from e if isinstance(res, EagerTensor) or ( diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index cfc90f3..a106b95 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -4,7 +4,6 @@ from onnx import ModelProto, TensorProto from ..reference import ExtendedReferenceEvaluator from .._helpers import np_dtype_to_tensor_dtype -from .npx_numpy_tensors_ops import ConstantOfShape from .npx_tensors import EagerTensor, JitTensor from .npx_types import DType, TensorType @@ -36,7 +35,7 @@ def __init__( onx: ModelProto, f: Callable, ): - self.ref = ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape]) + self.ref = ExtendedReferenceEvaluator(onx) self.input_names = input_names self.tensor_class = tensor_class self._f = f diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py index 737b15d..aa26127 100644 --- a/onnx_array_api/reference/evaluator.py +++ b/onnx_array_api/reference/evaluator.py @@ -1,9 +1,18 @@ +from logging import getLogger from typing import Any, Dict, List, Optional, Union from onnx import FunctionProto, ModelProto from onnx.defs import get_schema from onnx.reference import ReferenceEvaluator from onnx.reference.op_run import OpRun from .ops.op_cast_like import CastLike_15, CastLike_19 +from .ops.op_constant_of_shape import ConstantOfShape + +import onnx + +print(onnx.__file__) + + +logger = getLogger("onnx-array-api-eval") class ExtendedReferenceEvaluator(ReferenceEvaluator): @@ -24,6 +33,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator): default_ops = [ CastLike_15, CastLike_19, + ConstantOfShape, ] @staticmethod @@ -88,3 +98,10 @@ def __init__( new_ops=new_ops, **kwargs, ) + + def _log(self, level: int, pattern: str, *args: List[Any]) -> None: + if level < self.verbose: + new_args = [self._log_arg(a) for a in args] + print(pattern % tuple(new_args)) + else: + logger.debug(pattern, *args) diff --git a/onnx_array_api/npx/npx_numpy_tensors_ops.py b/onnx_array_api/reference/ops/op_constant_of_shape.py similarity index 78% rename from onnx_array_api/npx/npx_numpy_tensors_ops.py rename to onnx_array_api/reference/ops/op_constant_of_shape.py index b4639ae..33308af 100644 --- a/onnx_array_api/npx/npx_numpy_tensors_ops.py +++ b/onnx_array_api/reference/ops/op_constant_of_shape.py @@ -1,12 +1,18 @@ import numpy as np - from onnx.reference.op_run import OpRun class ConstantOfShape(OpRun): @staticmethod def _process(value): - cst = value[0] if isinstance(value, np.ndarray) else value + cst = value[0] if isinstance(value, np.ndarray) and value.size > 0 else value + if isinstance(value, np.ndarray): + if len(value.shape) == 0: + cst = value + elif value.size > 0: + cst = value.ravel()[0] + else: + raise ValueError(f"Unexpected fill_value={value!r}") if isinstance(cst, bool): cst = np.bool_(cst) elif isinstance(cst, int): From 26dfac72c0132d3769db1932cc2501b837b02b93 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Sun, 2 Jul 2023 11:38:35 +0200 Subject: [PATCH 3/3] fix full_like --- _unittests/onnx-numpy-skips.txt | 3 +-- _unittests/test_array_api.sh | 2 +- _unittests/ut_array_api/test_onnx_numpy.py | 6 +++--- azure-pipelines.yml | 7 ++++--- onnx_array_api/array_api/_onnx_common.py | 2 ++ 5 files changed, 11 insertions(+), 9 deletions(-) diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index b035005..1eac9e2 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -6,7 +6,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays array_api_tests/test_creation_functions.py::test_empty array_api_tests/test_creation_functions.py::test_empty_like array_api_tests/test_creation_functions.py::test_eye -# array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid -# array_api_tests/test_creation_functions.py::test_zeros_like +array_api_tests/test_creation_functions.py::test_zeros_like diff --git a/_unittests/test_array_api.sh b/_unittests/test_array_api.sh index 5df14e2..abab39b 100644 --- a/_unittests/test_array_api.sh +++ b/_unittests/test_array_api.sh @@ -1,4 +1,4 @@ export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy -pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_full_like || exit 1 +pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_ones_like || exit 1 # pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1 diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 577f64e..78f8872 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -129,8 +129,8 @@ def test_full_like_mx(self): if __name__ == "__main__": - import logging + # import logging - logging.basicConfig(level=logging.DEBUG) - TestOnnxNumpy().test_full_like_mx() + # logging.basicConfig(level=logging.DEBUG) + # TestOnnxNumpy().test_full_like_mx() unittest.main(verbosity=2) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index c449f2e..709ced3 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -246,9 +246,10 @@ jobs: architecture: 'x64' - script: gcc --version displayName: 'gcc version' - - script: | - brew update - displayName: 'brew update' + #- script: brew upgrade + # displayName: 'brew upgrade' + #- script: brew update + # displayName: 'brew update' - script: export displayName: 'export' - script: gcc --version diff --git a/onnx_array_api/array_api/_onnx_common.py b/onnx_array_api/array_api/_onnx_common.py index 52ae566..98a89f2 100644 --- a/onnx_array_api/array_api/_onnx_common.py +++ b/onnx_array_api/array_api/_onnx_common.py @@ -190,6 +190,8 @@ def full_like( if dtype is None: if isinstance(fill_value, TEagerTensor): dtype = fill_value.dtype + elif isinstance(x, TEagerTensor): + dtype = x.dtype return generic_full_like(x, fill_value=fill_value, dtype=dtype, order=order) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy