Skip to content

Commit 9de394e

Browse files
xadupresdpython
andauthored
Extends export onnx to code to support inner API (#47)
* Extend to use inner API * export subgraphs * update code * refactoring * add more tests * fix conversion * fix ut * fix ut * fix doc * doc * verbostiy * disable unstable test --------- Co-authored-by: Xavier Dupré <xavier.dupre@gmail.com>
1 parent 75d62a0 commit 9de394e

File tree

14 files changed

+1049
-129
lines changed

14 files changed

+1049
-129
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.1.3
55
+++++
66

7+
* :pr:`47`: extends export onnx to code to support inner API
78
* :pr:`46`: adds an export to convert an onnx graph into light API code
89
* :pr:`45`: fixes light API for operators with two outputs
910

_doc/api/light_api.rst

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,16 @@ Vars
4848
Classes for the Translater
4949
==========================
5050

51+
BaseEmitter
52+
+++++++++++
53+
54+
.. autoclass:: onnx_array_api.light_api.emitter.BaseEmitter
55+
:members:
56+
5157
Emitter
5258
+++++++
5359

54-
.. autoclass:: onnx_array_api.light_api.translate.Emitter
60+
.. autoclass:: onnx_array_api.light_api.emitter.Emitter
5561
:members:
5662

5763
EventType
@@ -60,6 +66,12 @@ EventType
6066
.. autoclass:: onnx_array_api.light_api.translate.EventType
6167
:members:
6268

69+
InnerEmitter
70+
++++++++++++
71+
72+
.. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter
73+
:members:
74+
6375
Translater
6476
++++++++++
6577

Binary file not shown.
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
import unittest
2+
from typing import Any, Dict, List, Optional
3+
from difflib import unified_diff
4+
import packaging.version as pv
5+
import numpy
6+
from numpy.testing import assert_allclose
7+
import onnx.backend.base
8+
import onnx.backend.test
9+
import onnx.shape_inference
10+
import onnx.version_converter
11+
from onnx import ModelProto, TensorProto, __version__ as onnx_version
12+
from onnx.helper import (
13+
make_function,
14+
make_graph,
15+
make_model,
16+
make_node,
17+
make_opsetid,
18+
make_tensor_value_info,
19+
)
20+
from onnx.numpy_helper import from_array, to_array
21+
from onnx.backend.base import Device, DeviceType
22+
from onnx_array_api.reference import ExtendedReferenceEvaluator
23+
from onnx_array_api.light_api import translate
24+
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
25+
26+
27+
class ReferenceImplementationError(RuntimeError):
28+
"Fails, export cannot be compared."
29+
pass
30+
31+
32+
class ExportWrapper:
33+
apis = ["onnx", "light"]
34+
35+
def __init__(self, model):
36+
self.model = model
37+
self.expected_sess = ExtendedReferenceEvaluator(self.model)
38+
39+
@property
40+
def input_names(self):
41+
return self.expected_sess.input_names
42+
43+
@property
44+
def input_types(self):
45+
return self.expected_sess.input_types
46+
47+
@property
48+
def output_names(self):
49+
return self.expected_sess.output_names
50+
51+
@property
52+
def output_types(self):
53+
return self.expected_sess.output_types
54+
55+
def run(
56+
self, names: Optional[List[str]], feeds: Optional[Dict[str, Any]] = None
57+
) -> List[Any]:
58+
try:
59+
expected = self.expected_sess.run(names, feeds)
60+
except (RuntimeError, AssertionError, TypeError, KeyError) as e:
61+
raise ReferenceImplementationError(
62+
f"ReferenceImplementation fails with {onnx_simple_text_plot(self.model)}"
63+
f"\n--RAW--\n{self.model}"
64+
) from e
65+
66+
for api in self.apis:
67+
try:
68+
code = translate(self.model, api=api)
69+
except NotImplementedError:
70+
continue
71+
except ValueError as e:
72+
raise AssertionError(
73+
f"Unable to translate model for api {api!r}, "
74+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
75+
f"\n--EXPECTED--\n{expected}"
76+
) from e
77+
try:
78+
code_compiled = compile(code, "<string>", mode="exec")
79+
except Exception as e:
80+
new_code = "\n".join(
81+
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
82+
)
83+
raise AssertionError(f"ERROR {e}\n{new_code}")
84+
85+
locs = {
86+
"np": numpy,
87+
"to_array": to_array,
88+
"from_array": from_array,
89+
"TensorProto": TensorProto,
90+
"make_function": make_function,
91+
"make_opsetid": make_opsetid,
92+
"make_model": make_model,
93+
"make_graph": make_graph,
94+
"make_node": make_node,
95+
"make_tensor_value_info": make_tensor_value_info,
96+
}
97+
globs = locs.copy()
98+
try:
99+
exec(code_compiled, globs, locs)
100+
except (TypeError, NameError, ValueError) as e:
101+
new_code = "\n".join(
102+
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
103+
)
104+
raise AssertionError(
105+
f"Unable to executed code for api {api!r}\n{new_code}"
106+
) from e
107+
export_model = locs["model"]
108+
ref = ExtendedReferenceEvaluator(export_model)
109+
try:
110+
got = ref.run(names, feeds)
111+
except (TypeError, AttributeError) as e:
112+
diff = "\n".join(
113+
unified_diff(
114+
str(self.model).split("\n"),
115+
str(export_model).split("\n"),
116+
fromfile="before",
117+
tofile="after",
118+
)
119+
)
120+
raise AssertionError(
121+
f"Unable to run the exported model for api {api!r}, "
122+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
123+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
124+
f"\n--CODE--\n{code}"
125+
f"\n--FEEDS--\n{feeds}"
126+
f"\n--EXPECTED--\n{expected}"
127+
f"\n--DIFF--\n{diff}"
128+
) from e
129+
if len(expected) != len(got):
130+
raise AssertionError(
131+
f"Unexpected number of outputs for api {api!r}, "
132+
f"{len(expected)} != {len(got)}."
133+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
134+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
135+
)
136+
for a, b in zip(expected, got):
137+
if not isinstance(a, numpy.ndarray):
138+
continue
139+
if a.shape != b.shape or a.dtype != b.dtype:
140+
raise AssertionError(
141+
f"Shape or type discrepancies for api {api!r}."
142+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
143+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
144+
)
145+
if a.dtype in (numpy.str_, object, numpy.object_) or isinstance(
146+
a.dtype, getattr(getattr(numpy, "dtypes", None), "StrDType", type)
147+
):
148+
if a.tolist() != b.tolist():
149+
raise AssertionError(
150+
f"Text discrepancies for api {api!r} with a.dtype={a.dtype} "
151+
f"and b.dtype={b.dtype}"
152+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
153+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
154+
)
155+
continue
156+
try:
157+
assert_allclose(a, b, atol=1e-3)
158+
except (AssertionError, TypeError) as e:
159+
raise AssertionError(
160+
f"Discrepancies for api {api!r} with a.dtype={a.dtype} "
161+
f"and b.dtype={b.dtype} (type-dtype={type(a.dtype)})"
162+
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
163+
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"
164+
) from e
165+
166+
return expected
167+
168+
169+
class ExportBackendRep(onnx.backend.base.BackendRep):
170+
def __init__(self, session):
171+
self._session = session
172+
173+
def run(self, inputs, **kwargs):
174+
if isinstance(inputs, numpy.ndarray):
175+
inputs = [inputs]
176+
if isinstance(inputs, list):
177+
if len(inputs) == len(self._session.input_names):
178+
feeds = dict(zip(self._session.input_names, inputs))
179+
else:
180+
feeds = {}
181+
pos_inputs = 0
182+
for inp, tshape in zip(
183+
self._session.input_names, self._session.input_types
184+
):
185+
shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
186+
if shape == inputs[pos_inputs].shape:
187+
feeds[inp] = inputs[pos_inputs]
188+
pos_inputs += 1
189+
if pos_inputs >= len(inputs):
190+
break
191+
elif isinstance(inputs, dict):
192+
feeds = inputs
193+
else:
194+
raise TypeError(f"Unexpected input type {type(inputs)!r}.")
195+
outs = self._session.run(None, feeds)
196+
return outs
197+
198+
199+
class ExportBackend(onnx.backend.base.Backend):
200+
@classmethod
201+
def is_opset_supported(cls, model): # pylint: disable=unused-argument
202+
return True, ""
203+
204+
@classmethod
205+
def supports_device(cls, device: str) -> bool:
206+
d = Device(device)
207+
return d.type == DeviceType.CPU # type: ignore[no-any-return]
208+
209+
@classmethod
210+
def create_inference_session(cls, model):
211+
return ExportWrapper(model)
212+
213+
@classmethod
214+
def prepare(
215+
cls, model: Any, device: str = "CPU", **kwargs: Any
216+
) -> ExportBackendRep:
217+
if isinstance(model, ExportWrapper):
218+
return ExportBackendRep(model)
219+
if isinstance(model, (str, bytes, ModelProto)):
220+
inf = cls.create_inference_session(model)
221+
return cls.prepare(inf, device, **kwargs)
222+
raise TypeError(f"Unexpected type {type(model)} for model.")
223+
224+
@classmethod
225+
def run_model(cls, model, inputs, device=None, **kwargs):
226+
rep = cls.prepare(model, device, **kwargs)
227+
return rep.run(inputs, **kwargs)
228+
229+
@classmethod
230+
def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
231+
raise NotImplementedError("Unable to run the model node by node.")
232+
233+
234+
backend_test = onnx.backend.test.BackendTest(ExportBackend, __name__)
235+
236+
# The following tests are too slow with the reference implementation (Conv).
237+
backend_test.exclude(
238+
"(FLOAT8|BFLOAT16|_opt_|_3d_|_momentum_|_4d_"
239+
"|test_adagrad"
240+
"|test_adam"
241+
"|test_ai_onnx_ml_"
242+
"|test_cast_FLOAT16"
243+
"|test_cast_FLOAT_to_STRING"
244+
"|test_castlike_FLOAT16"
245+
"|test_castlike_FLOAT_to_STRING"
246+
"|test_bernoulli"
247+
"|test_bvlc_alexnet"
248+
"|test_conv" # too long
249+
"|test_gradient_"
250+
"|test_densenet121"
251+
"|test_inception_v1"
252+
"|test_inception_v2"
253+
"|test_loop11_"
254+
"|test_loop16_seq_none"
255+
"|test_MaxPool2d"
256+
"|test_quantizelinear_e"
257+
"|test_resnet50"
258+
"|test_sequence_model"
259+
"|test_scan_sum"
260+
"|test_scatter_with_axis"
261+
"|test_scatter_without_axis"
262+
"|test_shufflenet"
263+
"|test_squeezenet"
264+
"|test_vgg19"
265+
"|test_zfnet512"
266+
")"
267+
)
268+
269+
if pv.Version(onnx_version) < pv.Version("1.16.0"):
270+
backend_test.exclude("(test_strnorm|test_range_)")
271+
272+
# The following tests cannot pass because they consists in generating random number.
273+
backend_test.exclude("(test_bernoulli)")
274+
275+
# import all test cases at global scope to make them visible to python.unittest
276+
globals().update(backend_test.test_cases)
277+
278+
if __name__ == "__main__":
279+
res = unittest.main(verbosity=2, exit=False)
280+
tests_run = res.result.testsRun
281+
errors = len(res.result.errors)
282+
skipped = len(res.result.skipped)
283+
unexpected_successes = len(res.result.unexpectedSuccesses)
284+
expected_failures = len(res.result.expectedFailures)
285+
print("---------------------------------")
286+
print(
287+
f"tests_run={tests_run} errors={errors} skipped={skipped} "
288+
f"unexpected_successes={unexpected_successes} "
289+
f"expected_failures={expected_failures}"
290+
)

_unittests/ut_light_api/test_light_api.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import unittest
2+
import sys
23
from typing import Callable, Optional
34
import numpy as np
45
from onnx import ModelProto
@@ -144,6 +145,7 @@ def list_ops_missing(self, n_inputs):
144145
f"{new_missing}\n{text}"
145146
)
146147

148+
@unittest.skipIf(sys.platform == "win32", reason="unstable test on Windows")
147149
def test_list_ops_missing(self):
148150
self.list_ops_missing(1)
149151
self.list_ops_missing(2)

_unittests/ut_light_api/test_translate.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,17 @@
66
from onnx.reference import ReferenceEvaluator
77
from onnx_array_api.ext_test_case import ExtTestCase
88
from onnx_array_api.light_api import start, translate
9+
from onnx_array_api.light_api.emitter import EventType
910

1011
OPSET_API = min(19, onnx_opset_version() - 1)
1112

1213

1314
class TestTranslate(ExtTestCase):
15+
def test_event_type(self):
16+
self.assertEqual(
17+
EventType.to_str(EventType.INITIALIZER), "EventType.INITIALIZER"
18+
)
19+
1420
def test_exp(self):
1521
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
1622
self.assertIsInstance(onx, ModelProto)
@@ -73,6 +79,8 @@ def test_transpose(self):
7379
"""
7480
(
7581
start(opset=19)
82+
.cst(np.array([-1, 1], dtype=np.int64))
83+
.rename('r')
7684
.vin('X', elem_type=TensorProto.FLOAT)
7785
.bring('X', 'r')
7886
.Reshape()

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy