diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 9d8d98d..b5e9d88 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`24`: add ExtendedReferenceEvaluator to support scenario for the Array API onnx does not support * :pr:`22`: support OrtValue in function :func:`ort_profile` * :pr:`17`: implements ArrayAPI * :pr:`3`: fixes Array API with onnxruntime and scikit-learn diff --git a/_doc/api/index.rst b/_doc/api/index.rst index 475fad6..a95b2f4 100644 --- a/_doc/api/index.rst +++ b/_doc/api/index.rst @@ -15,4 +15,5 @@ API onnx_tools ort plotting + reference tools diff --git a/_doc/api/reference.rst b/_doc/api/reference.rst new file mode 100644 index 0000000..acbf90a --- /dev/null +++ b/_doc/api/reference.rst @@ -0,0 +1,7 @@ +reference +========= + +ExtendedReferenceEvaluator +++++++++++++++++++++++++++ + +.. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator diff --git a/_unittests/onnx-numpy-skips.txt b/_unittests/onnx-numpy-skips.txt index 9a04400..a3eaa47 100644 --- a/_unittests/onnx-numpy-skips.txt +++ b/_unittests/onnx-numpy-skips.txt @@ -9,6 +9,4 @@ array_api_tests/test_creation_functions.py::test_eye array_api_tests/test_creation_functions.py::test_full_like array_api_tests/test_creation_functions.py::test_linspace array_api_tests/test_creation_functions.py::test_meshgrid -# Issue with CastLike and bfloat16 on onnx <= 1.15.0 -# array_api_tests/test_creation_functions.py::test_ones_like array_api_tests/test_creation_functions.py::test_zeros_like diff --git a/_unittests/ut_array_api/test_onnx_numpy.py b/_unittests/ut_array_api/test_onnx_numpy.py index 9e3efb7..859c802 100644 --- a/_unittests/ut_array_api/test_onnx_numpy.py +++ b/_unittests/ut_array_api/test_onnx_numpy.py @@ -1,8 +1,7 @@ import sys import unittest -from packaging.version import Version import numpy as np -from onnx import TensorProto, __version__ as onnx_ver +from onnx import TensorProto from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.array_api import onnx_numpy as xp from onnx_array_api.npx.npx_types import DType @@ -99,10 +98,6 @@ def test_arange_int00(self): expected = expected.astype(np.int64) self.assertEqualArray(matnp, expected) - @unittest.skipIf( - Version(onnx_ver) < Version("1.15.0"), - reason="Reference implementation of CastLike is bugged.", - ) def test_ones_like_uint16(self): x = EagerTensor(np.array(0, dtype=np.uint16)) y = np.ones_like(x.numpy()) diff --git a/_unittests/ut_reference/test_backend_extended_reference_evaluator.py b/_unittests/ut_reference/test_backend_extended_reference_evaluator.py new file mode 100644 index 0000000..4bc0927 --- /dev/null +++ b/_unittests/ut_reference/test_backend_extended_reference_evaluator.py @@ -0,0 +1,239 @@ +import os +import platform +import unittest +from typing import Any +import numpy +import onnx.backend.base +import onnx.backend.test +import onnx.shape_inference +import onnx.version_converter +from onnx import ModelProto +from onnx.backend.base import Device, DeviceType +from onnx.defs import onnx_opset_version +from onnx_array_api.reference import ExtendedReferenceEvaluator + + +class ExtendedReferenceEvaluatorBackendRep(onnx.backend.base.BackendRep): + def __init__(self, session): + self._session = session + + def run(self, inputs, **kwargs): + if isinstance(inputs, numpy.ndarray): + inputs = [inputs] + if isinstance(inputs, list): + if len(inputs) == len(self._session.input_names): + feeds = dict(zip(self._session.input_names, inputs)) + else: + feeds = {} + pos_inputs = 0 + for inp, tshape in zip( + self._session.input_names, self._session.input_types + ): + shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim) + if shape == inputs[pos_inputs].shape: + feeds[inp] = inputs[pos_inputs] + pos_inputs += 1 + if pos_inputs >= len(inputs): + break + elif isinstance(inputs, dict): + feeds = inputs + else: + raise TypeError(f"Unexpected input type {type(inputs)!r}.") + outs = self._session.run(None, feeds) + return outs + + +class ExtendedReferenceEvaluatorBackend(onnx.backend.base.Backend): + @classmethod + def is_opset_supported(cls, model): # pylint: disable=unused-argument + return True, "" + + @classmethod + def supports_device(cls, device: str) -> bool: + d = Device(device) + return d.type == DeviceType.CPU # type: ignore[no-any-return] + + @classmethod + def create_inference_session(cls, model): + return ExtendedReferenceEvaluator(model) + + @classmethod + def prepare( + cls, model: Any, device: str = "CPU", **kwargs: Any + ) -> ExtendedReferenceEvaluatorBackendRep: + # if isinstance(model, ExtendedReferenceEvaluatorBackendRep): + # return model + if isinstance(model, ExtendedReferenceEvaluator): + return ExtendedReferenceEvaluatorBackendRep(model) + if isinstance(model, (str, bytes, ModelProto)): + inf = cls.create_inference_session(model) + return cls.prepare(inf, device, **kwargs) + raise TypeError(f"Unexpected type {type(model)} for model.") + + @classmethod + def run_model(cls, model, inputs, device=None, **kwargs): + rep = cls.prepare(model, device, **kwargs) + return rep.run(inputs, **kwargs) + + @classmethod + def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): + raise NotImplementedError("Unable to run the model node by node.") + + +backend_test = onnx.backend.test.BackendTest( + ExtendedReferenceEvaluatorBackend, __name__ +) + +if os.getenv("APPVEYOR"): + backend_test.exclude("(test_vgg19|test_zfnet)") +if platform.architecture()[0] == "32bit": + backend_test.exclude("(test_vgg19|test_zfnet|test_bvlc_alexnet)") +if platform.system() == "Windows": + backend_test.exclude("test_sequence_model") + +if onnx_opset_version() < 21: + backend_test.exclude( + "(test_averagepool_2d_dilations" + "|test_if*" + "|test_loop*" + "|test_scan*" + "|test_sequence_map*" + ")" + ) + +if onnx_opset_version() < 19: + backend_test.exclude( + "(test_argm[ai][nx]_default_axis_example" + "|test_argm[ai][nx]_default_axis_random" + "|test_argm[ai][nx]_keepdims_example" + "|test_argm[ai][nx]_keepdims_random" + "|test_argm[ai][nx]_negative_axis_keepdims_example" + "|test_argm[ai][nx]_negative_axis_keepdims_random" + "|test_argm[ai][nx]_no_keepdims_example" + "|test_argm[ai][nx]_no_keepdims_random" + "|test_col2im_pads" + "|test_gru_batchwise" + "|test_gru_defaults" + "|test_gru_seq_length" + "|test_gru_with_initial_bias" + "|test_layer_normalization_2d_axis1_expanded" + "|test_layer_normalization_2d_axis_negative_1_expanded" + "|test_layer_normalization_3d_axis1_epsilon_expanded" + "|test_layer_normalization_3d_axis2_epsilon_expanded" + "|test_layer_normalization_3d_axis_negative_1_epsilon_expanded" + "|test_layer_normalization_3d_axis_negative_2_epsilon_expanded" + "|test_layer_normalization_4d_axis1_expanded" + "|test_layer_normalization_4d_axis2_expanded" + "|test_layer_normalization_4d_axis3_expanded" + "|test_layer_normalization_4d_axis_negative_1_expanded" + "|test_layer_normalization_4d_axis_negative_2_expanded" + "|test_layer_normalization_4d_axis_negative_3_expanded" + "|test_layer_normalization_default_axis_expanded" + "|test_logsoftmax_large_number_expanded" + "|test_lstm_batchwise" + "|test_lstm_defaults" + "|test_lstm_with_initial_bias" + "|test_lstm_with_peepholes" + "|test_mvn" + "|test_mvn_expanded" + "|test_softmax_large_number_expanded" + "|test_operator_reduced_mean" + "|test_operator_reduced_mean_keepdim)" + ) + +# The following tests are not supported. +backend_test.exclude( + "(test_gradient" + "|test_if_opt" + "|test_loop16_seq_none" + "|test_range_float_type_positive_delta_expanded" + "|test_range_int32_type_negative_delta_expanded" + "|test_scan_sum)" +) + +if onnx_opset_version() < 21: + # The following tests are using types not supported by NumPy. + # They could be if method to_array is extended to support custom + # types the same as the reference implementation does + # (see onnx.reference.op_run.to_array_extended). + backend_test.exclude( + "(test_cast_FLOAT_to_BFLOAT16" + "|test_cast_BFLOAT16_to_FLOAT" + "|test_cast_BFLOAT16_to_FLOAT" + "|test_castlike_BFLOAT16_to_FLOAT" + "|test_castlike_FLOAT_to_BFLOAT16" + "|test_castlike_FLOAT_to_BFLOAT16_expanded" + "|test_cast_no_saturate_" + "|_to_FLOAT8" + "|_FLOAT8" + "|test_quantizelinear_e4m3fn" + "|test_quantizelinear_e5m2" + ")" + ) + + # Disable test about float 8 + backend_test.exclude( + "(test_castlike_BFLOAT16*" + "|test_cast_BFLOAT16*" + "|test_cast_no_saturate*" + "|test_cast_FLOAT_to_FLOAT8*" + "|test_cast_FLOAT16_to_FLOAT8*" + "|test_cast_FLOAT8_to_*" + "|test_castlike_BFLOAT16*" + "|test_castlike_no_saturate*" + "|test_castlike_FLOAT_to_FLOAT8*" + "|test_castlike_FLOAT16_to_FLOAT8*" + "|test_castlike_FLOAT8_to_*" + "|test_quantizelinear_e*)" + ) + +# The following tests are too slow with the reference implementation (Conv). +backend_test.exclude( + "(test_bvlc_alexnet" + "|test_densenet121" + "|test_inception_v1" + "|test_inception_v2" + "|test_resnet50" + "|test_shufflenet" + "|test_squeezenet" + "|test_vgg19" + "|test_zfnet512)" +) + +# The following tests cannot pass because they consists in generating random number. +backend_test.exclude("(test_bernoulli)") + +if onnx_opset_version() < 21: + # The following tests fail due to a bug in the backend test comparison. + backend_test.exclude( + "(test_cast_FLOAT_to_STRING|test_castlike_FLOAT_to_STRING|test_strnorm)" + ) + + # The following tests fail due to a shape mismatch. + backend_test.exclude( + "(test_center_crop_pad_crop_axes_hwc_expanded|test_lppool_2d_dilations)" + ) + + # The following tests fail due to a type mismatch. + backend_test.exclude("(test_eyelike_without_dtype)") + +# The following tests fail due to discrepancies (small but still higher than 1e-7). +backend_test.exclude("test_adam_multiple") # 1e-2 + + +# import all test cases at global scope to make them visible to python.unittest +globals().update(backend_test.test_cases) + +if __name__ == "__main__": + res = unittest.main(verbosity=2, exit=False) + tests_run = res.result.testsRun + errors = len(res.result.errors) + skipped = len(res.result.skipped) + unexpected_successes = len(res.result.unexpectedSuccesses) + expected_failures = len(res.result.expectedFailures) + print("---------------------------------") + print( + f"tests_run={tests_run} errors={errors} skipped={skipped} " + f"unexpected_successes={unexpected_successes} " + f"expected_failures={expected_failures}" + ) diff --git a/onnx_array_api/npx/npx_numpy_tensors.py b/onnx_array_api/npx/npx_numpy_tensors.py index 80f530a..ba10d79 100644 --- a/onnx_array_api/npx/npx_numpy_tensors.py +++ b/onnx_array_api/npx/npx_numpy_tensors.py @@ -1,7 +1,7 @@ from typing import Any, Callable, List, Optional, Tuple import numpy as np from onnx import ModelProto, TensorProto -from onnx.reference import ReferenceEvaluator +from ..reference import ExtendedReferenceEvaluator from .._helpers import np_dtype_to_tensor_dtype from .npx_numpy_tensors_ops import ConstantOfShape from .npx_tensors import EagerTensor, JitTensor @@ -11,7 +11,7 @@ class NumpyTensor: """ Default backend based on - :func:`onnx.reference.ReferenceEvaluator`. + :func:`onnx_array_api.reference.ExtendedReferenceEvaluator`. :param input_names: input names :param onx: onnx model @@ -19,7 +19,7 @@ class NumpyTensor: class Evaluator: """ - Wraps class :class:`onnx.reference.ReferenceEvaluator` + Wraps class :class:`onnx_array_api.reference.ExtendedReferenceEvaluator` to have a signature closer to python function. :param tensor_class: class tensor such as :class:`NumpyTensor` @@ -35,7 +35,7 @@ def __init__( onx: ModelProto, f: Callable, ): - self.ref = ReferenceEvaluator(onx, new_ops=[ConstantOfShape]) + self.ref = ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape]) self.input_names = input_names self.tensor_class = tensor_class self._f = f diff --git a/onnx_array_api/reference/__init__.py b/onnx_array_api/reference/__init__.py new file mode 100644 index 0000000..e4db27c --- /dev/null +++ b/onnx_array_api/reference/__init__.py @@ -0,0 +1 @@ +from .evaluator import ExtendedReferenceEvaluator diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py new file mode 100644 index 0000000..737b15d --- /dev/null +++ b/onnx_array_api/reference/evaluator.py @@ -0,0 +1,90 @@ +from typing import Any, Dict, List, Optional, Union +from onnx import FunctionProto, ModelProto +from onnx.defs import get_schema +from onnx.reference import ReferenceEvaluator +from onnx.reference.op_run import OpRun +from .ops.op_cast_like import CastLike_15, CastLike_19 + + +class ExtendedReferenceEvaluator(ReferenceEvaluator): + """ + This class replaces the python implementation by custom implementation. + The Array API extends many operator to all types not supported + by the onnx specifications. The evaluator allows to test + scenarios outside what an onnx backend bound to the official onnx + operators definition could do. + + :: + + from onnx.reference import ReferenceEvaluator + from onnx.reference.c_ops import Conv + ref = ReferenceEvaluator(..., new_ops=[Conv]) + """ + + default_ops = [ + CastLike_15, + CastLike_19, + ] + + @staticmethod + def filter_ops(proto, new_ops, opsets): + if opsets is None and isinstance(proto, (ModelProto, FunctionProto)): + opsets = {d.domain: d.version for d in proto.opset_import} + best = {} + renamed = {} + for cl in new_ops: + if "_" not in cl.__name__: + continue + vers = cl.__name__.split("_") + try: + v = int(vers[-1]) + except ValueError: + # not a version + continue + if opsets is not None and v > opsets.get(cl.op_domain, 1): + continue + renamed[cl.__name__] = cl + key = cl.op_domain, "_".join(vers[:-1]) + if key not in best or best[key][0] < v: + best[key] = (v, cl) + + modified = [] + for cl in new_ops: + if cl.__name__ not in renamed: + modified.append(cl) + for k, v in best.items(): + atts = {"domain": k[0]} + bases = (v[1],) + if not hasattr(v[1], "op_schema"): + atts["op_schema"] = get_schema(k[1], v[0], domain=v[1].op_domain) + new_cl = type(k[1], bases, atts) + modified.append(new_cl) + + new_ops = modified + return new_ops + + def __init__( + self, + proto: Any, + opsets: Optional[Dict[str, int]] = None, + functions: Optional[List[Union[ReferenceEvaluator, FunctionProto]]] = None, + verbose: int = 0, + new_ops: Optional[List[OpRun]] = None, + **kwargs, + ): + if new_ops is None: + new_ops = ExtendedReferenceEvaluator.default_ops + else: + new_ops = new_ops.copy() + new_ops.extend(ExtendedReferenceEvaluator.default_ops) + new_ops = ExtendedReferenceEvaluator.filter_ops(proto, new_ops, opsets) + + ReferenceEvaluator.__init__( + self, + proto, + opsets=opsets, + functions=functions, + verbose=verbose, + new_ops=new_ops, + **kwargs, + ) diff --git a/onnx_array_api/reference/ops/__init__.py b/onnx_array_api/reference/ops/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/onnx_array_api/reference/ops/__init__.py @@ -0,0 +1 @@ + diff --git a/onnx_array_api/reference/ops/op_cast_like.py b/onnx_array_api/reference/ops/op_cast_like.py new file mode 100644 index 0000000..97cc798 --- /dev/null +++ b/onnx_array_api/reference/ops/op_cast_like.py @@ -0,0 +1,38 @@ +from onnx.helper import np_dtype_to_tensor_dtype +from onnx.onnx_pb import TensorProto +from onnx.reference.op_run import OpRun +from onnx.reference.ops.op_cast import ( + bfloat16, + cast_to, + float8e4m3fn, + float8e4m3fnuz, + float8e5m2, + float8e5m2fnuz, +) + + +def _cast_like(x, y, saturate): + if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16": + # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16 + to = TensorProto.BFLOAT16 + elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn": + to = TensorProto.FLOAT8E4M3FN + elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz": + to = TensorProto.FLOAT8E4M3FNUZ + elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2": + to = TensorProto.FLOAT8E5M2 + elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz": + to = TensorProto.FLOAT8E5M2FNUZ + else: + to = np_dtype_to_tensor_dtype(y.dtype) # type: ignore + return (cast_to(x, to, saturate),) + + +class CastLike_15(OpRun): + def _run(self, x, y): # type: ignore + return _cast_like(x, y, True) + + +class CastLike_19(OpRun): + def _run(self, x, y, saturate=None): # type: ignore + return _cast_like(x, y, saturate) diff --git a/pyproject.toml b/pyproject.toml index 60043b5..7e15de0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,5 +37,6 @@ max-complexity = 10 "onnx_array_api/npx/npx_tensors.py" = ["F821"] "onnx_array_api/npx/npx_var.py" = ["F821"] "onnx_array_api/profiling.py" = ["E731"] +"onnx_array_api/reference/__init__.py" = ["F401"] "_unittests/ut_npx/test_npx.py" = ["F821"] diff --git a/requirements-dev.txt b/requirements-dev.txt index 07fd7c3..4cc0562 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,7 +3,7 @@ black coverage flake8 furo -hypothesis +hypothesis<6.80.0 isort joblib lightgbm pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy