Skip to content

Commit 75d62a0

Browse files
authored
Add an export to convert an onnx graph into light API code (#46)
* Add an export to convert an onnx graph into light API code * fix unit tests * fix annotations * fix documentation * doc
1 parent dd11424 commit 75d62a0

File tree

9 files changed

+489
-11
lines changed

9 files changed

+489
-11
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.1.3
55
+++++
66

7+
* :pr:`46`: adds an export to convert an onnx graph into light API code
78
* :pr:`45`: fixes light API for operators with two outputs
89

910
0.1.2

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,4 +141,4 @@ The euclidean distance looks like the following:
141141
The library is released on
142142
`pypi/onnx-array-api <https://pypi.org/project/onnx-array-api/>`_
143143
and its documentation is published at
144-
`(Numpy) Array API for ONNX <https://sdpython.github.io/doc/onnx-array-api/dev/>`_.
144+
`APIs to create ONNX Graphs <https://sdpython.github.io/doc/onnx-array-api/dev/>`_.

_doc/api/light_api.rst

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,67 @@
22
onnx_array_api.light_api
33
========================
44

5+
6+
Main API
7+
========
8+
59
start
6-
=====
10+
+++++
711

812
.. autofunction:: onnx_array_api.light_api.start
913

14+
translate
15+
+++++++++
16+
17+
.. autofunction:: onnx_array_api.light_api.translate
18+
19+
Classes for the Light API
20+
=========================
21+
1022
OnnxGraph
11-
=========
23+
+++++++++
1224

1325
.. autoclass:: onnx_array_api.light_api.OnnxGraph
1426
:members:
1527

1628
BaseVar
17-
=======
29+
+++++++
1830

1931
.. autoclass:: onnx_array_api.light_api.var.BaseVar
2032
:members:
2133

2234
Var
23-
===
35+
+++
2436

2537
.. autoclass:: onnx_array_api.light_api.Var
2638
:members:
2739
:inherited-members:
2840

2941
Vars
30-
====
42+
++++
3143

3244
.. autoclass:: onnx_array_api.light_api.Vars
3345
:members:
3446
:inherited-members:
47+
48+
Classes for the Translater
49+
==========================
50+
51+
Emitter
52+
+++++++
53+
54+
.. autoclass:: onnx_array_api.light_api.translate.Emitter
55+
:members:
56+
57+
EventType
58+
+++++++++
59+
60+
.. autoclass:: onnx_array_api.light_api.translate.EventType
61+
:members:
62+
63+
Translater
64+
++++++++++
65+
66+
.. autoclass:: onnx_array_api.light_api.translate.Translater
67+
:members:
68+

_doc/index.rst

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ The objective is to speed up the implementation of converter libraries.
4545
CHANGELOGS
4646
license
4747

48-
**Numpy API**
48+
Numpy API
49+
+++++++++
4950

5051
Sources available on
5152
`github/onnx-array-api <https://github.com/sdpython/onnx-array-api>`_.
@@ -109,7 +110,8 @@ Sources available on
109110
res = jitted_myloss(x, y)
110111
print(to_dot(jitted_myloss.get_onnx()))
111112

112-
**Light API**
113+
Light API
114+
+++++++++
113115

114116
.. runpython::
115117
:showcode:
@@ -135,3 +137,9 @@ Sources available on
135137
)
136138

137139
print(onnx_simple_text_plot(model))
140+
141+
142+
Older versions
143+
++++++++++++++
144+
145+
* `0.1.2 <../v0.1.2/index.html>`_
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import unittest
2+
from textwrap import dedent
3+
import numpy as np
4+
from onnx import ModelProto, TensorProto
5+
from onnx.defs import onnx_opset_version
6+
from onnx.reference import ReferenceEvaluator
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.light_api import start, translate
9+
10+
OPSET_API = min(19, onnx_opset_version() - 1)
11+
12+
13+
class TestTranslate(ExtTestCase):
14+
def test_exp(self):
15+
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
16+
self.assertIsInstance(onx, ModelProto)
17+
self.assertIn("Exp", str(onx))
18+
ref = ReferenceEvaluator(onx)
19+
a = np.arange(10).astype(np.float32)
20+
got = ref.run(None, {"X": a})[0]
21+
self.assertEqualArray(np.exp(a), got)
22+
23+
code = translate(onx)
24+
expected = dedent(
25+
"""
26+
(
27+
start(opset=19)
28+
.vin('X', elem_type=TensorProto.FLOAT)
29+
.bring('X')
30+
.Exp()
31+
.rename('Y')
32+
.bring('Y')
33+
.vout(elem_type=TensorProto.FLOAT)
34+
.to_onnx()
35+
)"""
36+
).strip("\n")
37+
self.assertEqual(expected, code)
38+
39+
onx2 = (
40+
start(opset=19)
41+
.vin("X", elem_type=TensorProto.FLOAT)
42+
.bring("X")
43+
.Exp()
44+
.rename("Y")
45+
.bring("Y")
46+
.vout(elem_type=TensorProto.FLOAT)
47+
.to_onnx()
48+
)
49+
ref = ReferenceEvaluator(onx2)
50+
a = np.arange(10).astype(np.float32)
51+
got = ref.run(None, {"X": a})[0]
52+
self.assertEqualArray(np.exp(a), got)
53+
54+
def test_transpose(self):
55+
onx = (
56+
start(opset=19)
57+
.vin("X")
58+
.reshape((-1, 1))
59+
.Transpose(perm=[1, 0])
60+
.rename("Y")
61+
.vout()
62+
.to_onnx()
63+
)
64+
self.assertIsInstance(onx, ModelProto)
65+
self.assertIn("Transpose", str(onx))
66+
ref = ReferenceEvaluator(onx)
67+
a = np.arange(10).astype(np.float32)
68+
got = ref.run(None, {"X": a})[0]
69+
self.assertEqualArray(a.reshape((-1, 1)).T, got)
70+
71+
code = translate(onx)
72+
expected = dedent(
73+
"""
74+
(
75+
start(opset=19)
76+
.vin('X', elem_type=TensorProto.FLOAT)
77+
.bring('X', 'r')
78+
.Reshape()
79+
.rename('r0_0')
80+
.bring('r0_0')
81+
.Transpose(perm=[1, 0])
82+
.rename('Y')
83+
.bring('Y')
84+
.vout(elem_type=TensorProto.FLOAT)
85+
.to_onnx()
86+
)"""
87+
).strip("\n")
88+
self.assertEqual(expected, code)
89+
90+
def test_topk_reverse(self):
91+
onx = (
92+
start(opset=19)
93+
.vin("X", np.float32)
94+
.vin("K", np.int64)
95+
.bring("X", "K")
96+
.TopK(largest=0)
97+
.rename("Values", "Indices")
98+
.vout()
99+
.to_onnx()
100+
)
101+
self.assertIsInstance(onx, ModelProto)
102+
ref = ReferenceEvaluator(onx)
103+
x = np.array([[0, 1, 2, 3], [9, 8, 7, 6]], dtype=np.float32)
104+
k = np.array([2], dtype=np.int64)
105+
got = ref.run(None, {"X": x, "K": k})
106+
self.assertEqualArray(np.array([[0, 1], [6, 7]], dtype=np.float32), got[0])
107+
self.assertEqualArray(np.array([[0, 1], [3, 2]], dtype=np.int64), got[1])
108+
109+
code = translate(onx)
110+
expected = dedent(
111+
"""
112+
(
113+
start(opset=19)
114+
.vin('X', elem_type=TensorProto.FLOAT)
115+
.vin('K', elem_type=TensorProto.INT64)
116+
.bring('X', 'K')
117+
.TopK(axis=-1, largest=0, sorted=1)
118+
.rename('Values', 'Indices')
119+
.bring('Values')
120+
.vout(elem_type=TensorProto.FLOAT)
121+
.bring('Indices')
122+
.vout(elem_type=TensorProto.FLOAT)
123+
.to_onnx()
124+
)"""
125+
).strip("\n")
126+
self.assertEqual(expected, code)
127+
128+
129+
if __name__ == "__main__":
130+
# TestLightApi().test_topk()
131+
unittest.main(verbosity=2)

onnx_array_api/light_api/__init__.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Dict, Optional
2+
from onnx import ModelProto
23
from .model import OnnxGraph
4+
from .translate import Translater
35
from .var import Var, Vars
46

57

@@ -34,8 +36,48 @@ def start(
3436
from onnx_array_api.light_api import start
3537
3638
onx = (
37-
start().vin("X").vin("Y").bring("X", "Y").Add().rename("Z").vout().to_onnx()
39+
start()
40+
.vin("X")
41+
.vin("Y")
42+
.bring("X", "Y")
43+
.Add()
44+
.rename("Z")
45+
.vout()
46+
.to_onnx()
3847
)
3948
print(onx)
4049
"""
4150
return OnnxGraph(opset=opset, opsets=opsets, is_function=is_function)
51+
52+
53+
def translate(proto: ModelProto, single_line=False) -> str:
54+
"""
55+
Translates an ONNX proto into a code using :ref:`l-light-api`
56+
to describe the ONNX graph.
57+
58+
:param proto: model to translate
59+
:param single_line: as a single line or not
60+
:return: code
61+
62+
.. runpython::
63+
:showcode:
64+
65+
from onnx_array_api.light_api import start, translate
66+
67+
onx = (
68+
start()
69+
.vin("X")
70+
.reshape((-1, 1))
71+
.Transpose(perm=[1, 0])
72+
.rename("Y")
73+
.vout()
74+
.to_onnx()
75+
)
76+
code = translate(onx)
77+
print(code)
78+
"""
79+
tr = Translater(proto)
80+
rows = tr.export()
81+
if single_line:
82+
return ".".join(rows)
83+
return "".join(["(\n ", "\n .".join(rows), "\n)"])

onnx_array_api/light_api/annotations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
ELEMENT_TYPE_NAME = {
1313
getattr(TensorProto, k): k
1414
for k in dir(TensorProto)
15-
if isinstance(getattr(TensorProto, k), int)
15+
if isinstance(getattr(TensorProto, k), int) and "_" not in k
1616
}
1717

1818
_type_numpy = {

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy