Skip to content

Commit 61eec9d

Browse files
authored
Add ExtendedReferenceEvaluator to test scenario outside onnx specifications (#24)
* Add ExtendedReferenceEvaluator to test scenario outside onnx specifications * lint * update doc * req
1 parent 02ee072 commit 61eec9d

File tree

13 files changed

+385
-13
lines changed

13 files changed

+385
-13
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`24`: add ExtendedReferenceEvaluator to support scenario for the Array API onnx does not support
78
* :pr:`22`: support OrtValue in function :func:`ort_profile`
89
* :pr:`17`: implements ArrayAPI
910
* :pr:`3`: fixes Array API with onnxruntime and scikit-learn

_doc/api/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ API
1515
onnx_tools
1616
ort
1717
plotting
18+
reference
1819
tools

_doc/api/reference.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
reference
2+
=========
3+
4+
ExtendedReferenceEvaluator
5+
++++++++++++++++++++++++++
6+
7+
.. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator

_unittests/onnx-numpy-skips.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,4 @@ array_api_tests/test_creation_functions.py::test_eye
99
array_api_tests/test_creation_functions.py::test_full_like
1010
array_api_tests/test_creation_functions.py::test_linspace
1111
array_api_tests/test_creation_functions.py::test_meshgrid
12-
# Issue with CastLike and bfloat16 on onnx <= 1.15.0
13-
# array_api_tests/test_creation_functions.py::test_ones_like
1412
array_api_tests/test_creation_functions.py::test_zeros_like

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import sys
22
import unittest
3-
from packaging.version import Version
43
import numpy as np
5-
from onnx import TensorProto, __version__ as onnx_ver
4+
from onnx import TensorProto
65
from onnx_array_api.ext_test_case import ExtTestCase
76
from onnx_array_api.array_api import onnx_numpy as xp
87
from onnx_array_api.npx.npx_types import DType
@@ -99,10 +98,6 @@ def test_arange_int00(self):
9998
expected = expected.astype(np.int64)
10099
self.assertEqualArray(matnp, expected)
101100

102-
@unittest.skipIf(
103-
Version(onnx_ver) < Version("1.15.0"),
104-
reason="Reference implementation of CastLike is bugged.",
105-
)
106101
def test_ones_like_uint16(self):
107102
x = EagerTensor(np.array(0, dtype=np.uint16))
108103
y = np.ones_like(x.numpy())
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
import os
2+
import platform
3+
import unittest
4+
from typing import Any
5+
import numpy
6+
import onnx.backend.base
7+
import onnx.backend.test
8+
import onnx.shape_inference
9+
import onnx.version_converter
10+
from onnx import ModelProto
11+
from onnx.backend.base import Device, DeviceType
12+
from onnx.defs import onnx_opset_version
13+
from onnx_array_api.reference import ExtendedReferenceEvaluator
14+
15+
16+
class ExtendedReferenceEvaluatorBackendRep(onnx.backend.base.BackendRep):
17+
def __init__(self, session):
18+
self._session = session
19+
20+
def run(self, inputs, **kwargs):
21+
if isinstance(inputs, numpy.ndarray):
22+
inputs = [inputs]
23+
if isinstance(inputs, list):
24+
if len(inputs) == len(self._session.input_names):
25+
feeds = dict(zip(self._session.input_names, inputs))
26+
else:
27+
feeds = {}
28+
pos_inputs = 0
29+
for inp, tshape in zip(
30+
self._session.input_names, self._session.input_types
31+
):
32+
shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
33+
if shape == inputs[pos_inputs].shape:
34+
feeds[inp] = inputs[pos_inputs]
35+
pos_inputs += 1
36+
if pos_inputs >= len(inputs):
37+
break
38+
elif isinstance(inputs, dict):
39+
feeds = inputs
40+
else:
41+
raise TypeError(f"Unexpected input type {type(inputs)!r}.")
42+
outs = self._session.run(None, feeds)
43+
return outs
44+
45+
46+
class ExtendedReferenceEvaluatorBackend(onnx.backend.base.Backend):
47+
@classmethod
48+
def is_opset_supported(cls, model): # pylint: disable=unused-argument
49+
return True, ""
50+
51+
@classmethod
52+
def supports_device(cls, device: str) -> bool:
53+
d = Device(device)
54+
return d.type == DeviceType.CPU # type: ignore[no-any-return]
55+
56+
@classmethod
57+
def create_inference_session(cls, model):
58+
return ExtendedReferenceEvaluator(model)
59+
60+
@classmethod
61+
def prepare(
62+
cls, model: Any, device: str = "CPU", **kwargs: Any
63+
) -> ExtendedReferenceEvaluatorBackendRep:
64+
# if isinstance(model, ExtendedReferenceEvaluatorBackendRep):
65+
# return model
66+
if isinstance(model, ExtendedReferenceEvaluator):
67+
return ExtendedReferenceEvaluatorBackendRep(model)
68+
if isinstance(model, (str, bytes, ModelProto)):
69+
inf = cls.create_inference_session(model)
70+
return cls.prepare(inf, device, **kwargs)
71+
raise TypeError(f"Unexpected type {type(model)} for model.")
72+
73+
@classmethod
74+
def run_model(cls, model, inputs, device=None, **kwargs):
75+
rep = cls.prepare(model, device, **kwargs)
76+
return rep.run(inputs, **kwargs)
77+
78+
@classmethod
79+
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
80+
raise NotImplementedError("Unable to run the model node by node.")
81+
82+
83+
backend_test = onnx.backend.test.BackendTest(
84+
ExtendedReferenceEvaluatorBackend, __name__
85+
)
86+
87+
if os.getenv("APPVEYOR"):
88+
backend_test.exclude("(test_vgg19|test_zfnet)")
89+
if platform.architecture()[0] == "32bit":
90+
backend_test.exclude("(test_vgg19|test_zfnet|test_bvlc_alexnet)")
91+
if platform.system() == "Windows":
92+
backend_test.exclude("test_sequence_model")
93+
94+
if onnx_opset_version() < 21:
95+
backend_test.exclude(
96+
"(test_averagepool_2d_dilations"
97+
"|test_if*"
98+
"|test_loop*"
99+
"|test_scan*"
100+
"|test_sequence_map*"
101+
")"
102+
)
103+
104+
if onnx_opset_version() < 19:
105+
backend_test.exclude(
106+
"(test_argm[ai][nx]_default_axis_example"
107+
"|test_argm[ai][nx]_default_axis_random"
108+
"|test_argm[ai][nx]_keepdims_example"
109+
"|test_argm[ai][nx]_keepdims_random"
110+
"|test_argm[ai][nx]_negative_axis_keepdims_example"
111+
"|test_argm[ai][nx]_negative_axis_keepdims_random"
112+
"|test_argm[ai][nx]_no_keepdims_example"
113+
"|test_argm[ai][nx]_no_keepdims_random"
114+
"|test_col2im_pads"
115+
"|test_gru_batchwise"
116+
"|test_gru_defaults"
117+
"|test_gru_seq_length"
118+
"|test_gru_with_initial_bias"
119+
"|test_layer_normalization_2d_axis1_expanded"
120+
"|test_layer_normalization_2d_axis_negative_1_expanded"
121+
"|test_layer_normalization_3d_axis1_epsilon_expanded"
122+
"|test_layer_normalization_3d_axis2_epsilon_expanded"
123+
"|test_layer_normalization_3d_axis_negative_1_epsilon_expanded"
124+
"|test_layer_normalization_3d_axis_negative_2_epsilon_expanded"
125+
"|test_layer_normalization_4d_axis1_expanded"
126+
"|test_layer_normalization_4d_axis2_expanded"
127+
"|test_layer_normalization_4d_axis3_expanded"
128+
"|test_layer_normalization_4d_axis_negative_1_expanded"
129+
"|test_layer_normalization_4d_axis_negative_2_expanded"
130+
"|test_layer_normalization_4d_axis_negative_3_expanded"
131+
"|test_layer_normalization_default_axis_expanded"
132+
"|test_logsoftmax_large_number_expanded"
133+
"|test_lstm_batchwise"
134+
"|test_lstm_defaults"
135+
"|test_lstm_with_initial_bias"
136+
"|test_lstm_with_peepholes"
137+
"|test_mvn"
138+
"|test_mvn_expanded"
139+
"|test_softmax_large_number_expanded"
140+
"|test_operator_reduced_mean"
141+
"|test_operator_reduced_mean_keepdim)"
142+
)
143+
144+
# The following tests are not supported.
145+
backend_test.exclude(
146+
"(test_gradient"
147+
"|test_if_opt"
148+
"|test_loop16_seq_none"
149+
"|test_range_float_type_positive_delta_expanded"
150+
"|test_range_int32_type_negative_delta_expanded"
151+
"|test_scan_sum)"
152+
)
153+
154+
if onnx_opset_version() < 21:
155+
# The following tests are using types not supported by NumPy.
156+
# They could be if method to_array is extended to support custom
157+
# types the same as the reference implementation does
158+
# (see onnx.reference.op_run.to_array_extended).
159+
backend_test.exclude(
160+
"(test_cast_FLOAT_to_BFLOAT16"
161+
"|test_cast_BFLOAT16_to_FLOAT"
162+
"|test_cast_BFLOAT16_to_FLOAT"
163+
"|test_castlike_BFLOAT16_to_FLOAT"
164+
"|test_castlike_FLOAT_to_BFLOAT16"
165+
"|test_castlike_FLOAT_to_BFLOAT16_expanded"
166+
"|test_cast_no_saturate_"
167+
"|_to_FLOAT8"
168+
"|_FLOAT8"
169+
"|test_quantizelinear_e4m3fn"
170+
"|test_quantizelinear_e5m2"
171+
")"
172+
)
173+
174+
# Disable test about float 8
175+
backend_test.exclude(
176+
"(test_castlike_BFLOAT16*"
177+
"|test_cast_BFLOAT16*"
178+
"|test_cast_no_saturate*"
179+
"|test_cast_FLOAT_to_FLOAT8*"
180+
"|test_cast_FLOAT16_to_FLOAT8*"
181+
"|test_cast_FLOAT8_to_*"
182+
"|test_castlike_BFLOAT16*"
183+
"|test_castlike_no_saturate*"
184+
"|test_castlike_FLOAT_to_FLOAT8*"
185+
"|test_castlike_FLOAT16_to_FLOAT8*"
186+
"|test_castlike_FLOAT8_to_*"
187+
"|test_quantizelinear_e*)"
188+
)
189+
190+
# The following tests are too slow with the reference implementation (Conv).
191+
backend_test.exclude(
192+
"(test_bvlc_alexnet"
193+
"|test_densenet121"
194+
"|test_inception_v1"
195+
"|test_inception_v2"
196+
"|test_resnet50"
197+
"|test_shufflenet"
198+
"|test_squeezenet"
199+
"|test_vgg19"
200+
"|test_zfnet512)"
201+
)
202+
203+
# The following tests cannot pass because they consists in generating random number.
204+
backend_test.exclude("(test_bernoulli)")
205+
206+
if onnx_opset_version() < 21:
207+
# The following tests fail due to a bug in the backend test comparison.
208+
backend_test.exclude(
209+
"(test_cast_FLOAT_to_STRING|test_castlike_FLOAT_to_STRING|test_strnorm)"
210+
)
211+
212+
# The following tests fail due to a shape mismatch.
213+
backend_test.exclude(
214+
"(test_center_crop_pad_crop_axes_hwc_expanded|test_lppool_2d_dilations)"
215+
)
216+
217+
# The following tests fail due to a type mismatch.
218+
backend_test.exclude("(test_eyelike_without_dtype)")
219+
220+
# The following tests fail due to discrepancies (small but still higher than 1e-7).
221+
backend_test.exclude("test_adam_multiple") # 1e-2
222+
223+
224+
# import all test cases at global scope to make them visible to python.unittest
225+
globals().update(backend_test.test_cases)
226+
227+
if __name__ == "__main__":
228+
res = unittest.main(verbosity=2, exit=False)
229+
tests_run = res.result.testsRun
230+
errors = len(res.result.errors)
231+
skipped = len(res.result.skipped)
232+
unexpected_successes = len(res.result.unexpectedSuccesses)
233+
expected_failures = len(res.result.expectedFailures)
234+
print("---------------------------------")
235+
print(
236+
f"tests_run={tests_run} errors={errors} skipped={skipped} "
237+
f"unexpected_successes={unexpected_successes} "
238+
f"expected_failures={expected_failures}"
239+
)

onnx_array_api/npx/npx_numpy_tensors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Callable, List, Optional, Tuple
22
import numpy as np
33
from onnx import ModelProto, TensorProto
4-
from onnx.reference import ReferenceEvaluator
4+
from ..reference import ExtendedReferenceEvaluator
55
from .._helpers import np_dtype_to_tensor_dtype
66
from .npx_numpy_tensors_ops import ConstantOfShape
77
from .npx_tensors import EagerTensor, JitTensor
@@ -11,15 +11,15 @@
1111
class NumpyTensor:
1212
"""
1313
Default backend based on
14-
:func:`onnx.reference.ReferenceEvaluator`.
14+
:func:`onnx_array_api.reference.ExtendedReferenceEvaluator`.
1515
1616
:param input_names: input names
1717
:param onx: onnx model
1818
"""
1919

2020
class Evaluator:
2121
"""
22-
Wraps class :class:`onnx.reference.ReferenceEvaluator`
22+
Wraps class :class:`onnx_array_api.reference.ExtendedReferenceEvaluator`
2323
to have a signature closer to python function.
2424
2525
:param tensor_class: class tensor such as :class:`NumpyTensor`
@@ -35,7 +35,7 @@ def __init__(
3535
onx: ModelProto,
3636
f: Callable,
3737
):
38-
self.ref = ReferenceEvaluator(onx, new_ops=[ConstantOfShape])
38+
self.ref = ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape])
3939
self.input_names = input_names
4040
self.tensor_class = tensor_class
4141
self._f = f

onnx_array_api/reference/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .evaluator import ExtendedReferenceEvaluator

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy