Skip to content

Commit 7895c27

Browse files
authored
Support translation of local functions (#60)
* add function to translate functions * doc * fix translation of local functions * refactoring * fix missing import * verbose * link
1 parent 71aa3a0 commit 7895c27

File tree

12 files changed

+492
-147
lines changed

12 files changed

+492
-147
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`60`: supports translation of local functions
78
* :pr:`59`: add methods to update nodes in GraphAPI
89

910
0.1.3

_doc/api/light_api.rst

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ translate
1616

1717
.. autofunction:: onnx_array_api.light_api.translate
1818

19+
make_helper
20+
+++++++++++
21+
22+
.. autofunction:: onnx_array_api.light_api.make_helper.make_node_extended
23+
24+
.. autofunction:: onnx_array_api.light_api.make_helper.make_ref_attribute
25+
1926
Classes for the Light API
2027
=========================
2128

@@ -68,19 +75,13 @@ Classes for the Translater
6875
BaseEmitter
6976
+++++++++++
7077

71-
.. autoclass:: onnx_array_api.light_api.emitter.BaseEmitter
72-
:members:
73-
74-
Emitter
75-
+++++++
76-
77-
.. autoclass:: onnx_array_api.light_api.emitter.Emitter
78+
.. autoclass:: onnx_array_api.light_api.base_emitter.BaseEmitter
7879
:members:
7980

8081
EventType
8182
+++++++++
8283

83-
.. autoclass:: onnx_array_api.light_api.translate.EventType
84+
.. autoclass:: onnx_array_api.light_api.base_emitter.EventType
8485
:members:
8586

8687
InnerEmitter
@@ -89,6 +90,12 @@ InnerEmitter
8990
.. autoclass:: onnx_array_api.light_api.inner_emitter.InnerEmitter
9091
:members:
9192

93+
LightEmitter
94+
++++++++++++
95+
96+
.. autoclass:: onnx_array_api.light_api.light_emitter.LightEmitter
97+
:members:
98+
9299
Translater
93100
++++++++++
94101

Binary file not shown.

_unittests/ut_light_api/test_backend_export.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import sys
12
import unittest
23
from typing import Any, Dict, List, Optional
34
from difflib import unified_diff
@@ -17,12 +18,16 @@
1718
make_opsetid,
1819
make_tensor_value_info,
1920
)
21+
from onnx.reference.op_run import to_array_extended
2022
from onnx.numpy_helper import from_array, to_array
2123
from onnx.backend.base import Device, DeviceType
2224
from onnx_array_api.reference import ExtendedReferenceEvaluator
25+
from onnx_array_api.light_api.make_helper import make_node_extended
2326
from onnx_array_api.light_api import translate
2427
from onnx_array_api.plotting.text_plot import onnx_simple_text_plot
2528

29+
verbosity = 10 if "-v" in sys.argv or "--verbose" in sys.argv else 0
30+
2631

2732
class ReferenceImplementationError(RuntimeError):
2833
"Fails, export cannot be compared."
@@ -34,7 +39,7 @@ class ExportWrapper:
3439

3540
def __init__(self, model):
3641
self.model = model
37-
self.expected_sess = ExtendedReferenceEvaluator(self.model)
42+
self.expected_sess = ExtendedReferenceEvaluator(self.model, verbose=verbosity)
3843

3944
@property
4045
def input_names(self):
@@ -85,13 +90,15 @@ def run(
8590
locs = {
8691
"np": numpy,
8792
"to_array": to_array,
93+
"to_array_extended": to_array_extended,
8894
"from_array": from_array,
8995
"TensorProto": TensorProto,
9096
"make_function": make_function,
9197
"make_opsetid": make_opsetid,
9298
"make_model": make_model,
9399
"make_graph": make_graph,
94100
"make_node": make_node,
101+
"make_node_extended": make_node_extended,
95102
"make_tensor_value_info": make_tensor_value_info,
96103
}
97104
globs = locs.copy()
@@ -105,7 +112,7 @@ def run(
105112
f"Unable to executed code for api {api!r}\n{new_code}"
106113
) from e
107114
export_model = locs["model"]
108-
ref = ExtendedReferenceEvaluator(export_model)
115+
ref = ExtendedReferenceEvaluator(export_model, verbose=verbosity)
109116
try:
110117
got = ref.run(names, feeds)
111118
except (TypeError, AttributeError) as e:

_unittests/ut_light_api/test_translate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from onnx.reference import ReferenceEvaluator
77
from onnx_array_api.ext_test_case import ExtTestCase
88
from onnx_array_api.light_api import start, translate, g
9-
from onnx_array_api.light_api.emitter import EventType
9+
from onnx_array_api.light_api.base_emitter import EventType
1010

1111
OPSET_API = min(19, onnx_opset_version() - 1)
1212

_unittests/ut_light_api/test_translate_classic.py

Lines changed: 116 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from onnx import ModelProto, TensorProto, load
66
from onnx.defs import onnx_opset_version
77
from onnx.reference import ReferenceEvaluator
8+
from onnx.reference.op_run import OpRun
89
from onnx.helper import (
910
make_tensor_value_info,
1011
make_node,
@@ -68,7 +69,7 @@ def test_exp(self):
6869
functions = []
6970
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
7071
nodes.append(
71-
make_node(
72+
make_node_extended(
7273
'Exp',
7374
['X'],
7475
['Y']
@@ -144,14 +145,14 @@ def test_transpose(self):
144145
)
145146
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
146147
nodes.append(
147-
make_node(
148+
make_node_extended(
148149
'Reshape',
149150
['X', 'r'],
150151
['r0_0']
151152
)
152153
)
153154
nodes.append(
154-
make_node(
155+
make_node_extended(
155156
'Transpose',
156157
['r0_0'],
157158
['Y'],
@@ -210,7 +211,7 @@ def test_topk_reverse(self):
210211
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
211212
inputs.append(make_tensor_value_info('K', TensorProto.INT64, shape=[]))
212213
nodes.append(
213-
make_node(
214+
make_node_extended(
214215
'TopK',
215216
['X', 'K'],
216217
['Values', 'Indices'],
@@ -264,7 +265,6 @@ def test_aionnxml(self):
264265
.to_onnx()
265266
)
266267
code = translate(onx, api="onnx")
267-
print(code)
268268
expected = dedent(
269269
"""
270270
opset_imports = [
@@ -285,14 +285,14 @@ def test_aionnxml(self):
285285
)
286286
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
287287
nodes.append(
288-
make_node(
288+
make_node_extended(
289289
'Reshape',
290290
['X', 'r'],
291291
['USE']
292292
)
293293
)
294294
nodes.append(
295-
make_node(
295+
make_node_extended(
296296
'Normalizer',
297297
['USE'],
298298
['Y'],
@@ -318,7 +318,115 @@ def test_aionnxml(self):
318318
self.maxDiff = None
319319
self.assertEqual(expected, code)
320320

321+
@classmethod
322+
def _code_line(cls, code):
323+
lines = code.split("\n")
324+
return "\n".join(f"{i+1:03d} {line}" for i, line in enumerate(lines))
325+
326+
@classmethod
327+
def _run(cls, code):
328+
try:
329+
code_compiled = compile(code, "<string>", mode="exec")
330+
except Exception as e:
331+
raise AssertionError(
332+
f"Compilation failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}"
333+
) from e
334+
335+
import onnx
336+
import onnx.helper
337+
import onnx.numpy_helper
338+
import onnx_array_api.light_api.make_helper
339+
import onnx.reference.custom_element_types
340+
341+
def from_array_extended(tensor, name=None):
342+
dt = tensor.dtype
343+
if (
344+
dt == onnx.reference.custom_element_types.float8e4m3fn
345+
and dt.descr[0][0] == "e4m3fn"
346+
):
347+
to = TensorProto.FLOAT8E4M3FN
348+
dt_to = np.uint8
349+
elif (
350+
dt == onnx.reference.custom_element_types.bfloat16
351+
and dt.descr[0][0] == "bfloat16"
352+
):
353+
to = TensorProto.BFLOAT16
354+
dt_to = np.uint16
355+
else:
356+
return onnx.numpy_helper.from_array(tensor, name)
357+
358+
t = onnx.numpy_helper.from_array(tensor.astype(dt_to), name)
359+
t.data_type = to
360+
return t
361+
362+
globs = onnx.__dict__.copy()
363+
globs.update(onnx.helper.__dict__)
364+
globs.update(onnx.numpy_helper.__dict__)
365+
globs.update(onnx_array_api.light_api.make_helper.__dict__)
366+
globs.update(onnx.reference.custom_element_types.__dict__)
367+
globs["from_array_extended"] = from_array_extended
368+
locs = {}
369+
try:
370+
exec(code_compiled, globs, locs)
371+
except Exception as e:
372+
raise AssertionError(
373+
f"Execution failed due to {e}\n---\n{cls._code_line(code)}\n---\n{e}"
374+
) from e
375+
return globs, locs
376+
377+
def test_remove_nodes(self):
378+
path = os.path.join(
379+
os.path.dirname(__file__), "_data", "custom_ops_type_inference_fails_0.onnx"
380+
)
381+
onx = load(path)
382+
code = translate(onx, api="onnx")
383+
_, locs = self._run(code)
384+
self.assertIn("model", locs)
385+
model = locs["model"]
386+
x = np.arange(4).reshape((-1, 2)).astype(np.float32)
387+
feeds = {"X": x}
388+
389+
class CustomGemmFloat8E4M3FN(OpRun):
390+
op_domain = "onnx_extented.ortops.tutorial.cpu"
391+
392+
def _run(
393+
self,
394+
x,
395+
y,
396+
bias=None,
397+
scale_x=None,
398+
scale_y=None,
399+
scale_z=None,
400+
transA=False,
401+
transB=False,
402+
dtype=None,
403+
rowMajor=None,
404+
computeType=None,
405+
):
406+
if scale_x is not None:
407+
x = x * scale_x
408+
if transA:
409+
x = x.T
410+
if scale_y is not None:
411+
y = y * scale_y
412+
if transB:
413+
y = y.T
414+
z = x @ y
415+
if bias is not None:
416+
z += bias
417+
if scale_z is not None:
418+
z = z / scale_z
419+
return (z,)
420+
421+
ref = ReferenceEvaluator(onx, new_ops=[CustomGemmFloat8E4M3FN])
422+
expected = ref.run(None, feeds)[0]
423+
ref2 = ReferenceEvaluator(model, new_ops=[CustomGemmFloat8E4M3FN])
424+
got = ref2.run(None, feeds)[0]
425+
self.assertEqualArray(expected, got)
426+
427+
# with open("debug_test_remove_nodes.py", "w") as f:
428+
# f.write(code)
429+
321430

322431
if __name__ == "__main__":
323-
# TestLightApi().test_topk()
324432
unittest.main(verbosity=2)

onnx_array_api/light_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def translate(proto: ModelProto, single_line: bool = False, api: str = "light")
6767
:param single_line: as a single line or not
6868
:param api: API to export into,
6969
default is `"light"` and this is handle by class
70-
:class:`onnx_array_api.light_api.emitter.Emitter`,
70+
:class:`onnx_array_api.light_api.light_emitter.LightEmitter`,
7171
another value is `"onnx"` which is the inner API implemented
7272
in onnx package.
7373
:return: code

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy