Skip to content

Commit 664e084

Browse files
authored
Improves translation to GraphBuilder (#95)
* Improves translation to GraphBuilder * ch * fix issue * ir * urls * check
1 parent 689cc6f commit 664e084

File tree

6 files changed

+127
-17
lines changed

6 files changed

+127
-17
lines changed

.github/workflows/check-urls.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,6 @@ jobs:
4242
print_all: false
4343
timeout: 2
4444
retry_count# : 2
45-
exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document
46-
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/
45+
exclude_urls: https://hal.archives-ouvertes.fr/hal-00990252/document,https://github.com/onnx/tensorflow-onnx
46+
exclude_patterns: https://www.data.gouv.fr/fr/datasets/r/e3d83ab3-dc52-4c99-abaf-8a38050cc68c,https://dev.azure.com/,https://github.com/onnx/tensorflow-onnx
4747
# force_pass : true

CHANGELOGS.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.3.1
5+
+++++
6+
7+
* :pr:`95`: improves translation to GraphBuilder
8+
49
0.3.0
510
+++++
611

_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from onnx_array_api.ext_test_case import ExtTestCase
99
from onnx_array_api.light_api import start
1010
from onnx_array_api.graph_api import GraphBuilder
11-
from onnx_array_api.translate_api import translate
11+
from onnx_array_api.translate_api import translate, Translater
12+
from onnx_array_api.translate_api.builder_emitter import BuilderEmitter
1213

1314

1415
OPSET_API = min(19, onnx_opset_version() - 1)
@@ -19,7 +20,7 @@ def setUp(self):
1920
self.maxDiff = None
2021

2122
def test_exp(self):
22-
onx = start(opset=19).vin("X").Exp().rename("Y").vout().to_onnx()
23+
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
2324
self.assertIsInstance(onx, ModelProto)
2425
self.assertIn("Exp", str(onx))
2526
ref = ReferenceEvaluator(onx)
@@ -38,7 +39,7 @@ def light_api(
3839
op.Identity(Y, outputs=["Y"])
3940
return Y
4041
41-
g = GraphBuilder({'': 19})
42+
g = GraphBuilder({'': 19}, ir_version=10)
4243
g.make_tensor_input("X", TensorProto.FLOAT, ())
4344
light_api(g.op, "X")
4445
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -68,7 +69,7 @@ def light_api(
6869

6970
def test_zdoc(self):
7071
onx = (
71-
start(opset=19)
72+
start(opset=19, ir_version=10)
7273
.vin("X")
7374
.reshape((-1, 1))
7475
.Transpose(perm=[1, 0])
@@ -89,7 +90,7 @@ def light_api(
8990
op.Identity(Y, outputs=["Y"])
9091
return Y
9192
92-
g = GraphBuilder({'': 19})
93+
g = GraphBuilder({'': 19}, ir_version=10)
9394
g.make_tensor_input("X", TensorProto.FLOAT, ())
9495
light_api(g.op, "X")
9596
g.make_tensor_output("Y", TensorProto.FLOAT, ())
@@ -117,6 +118,62 @@ def light_api(
117118
self.assertNotEmpty(model)
118119
check_model(model)
119120

121+
def test_exp_f(self):
122+
onx = start(opset=19, ir_version=10).vin("X").Exp().rename("Y").vout().to_onnx()
123+
self.assertIsInstance(onx, ModelProto)
124+
self.assertIn("Exp", str(onx))
125+
ref = ReferenceEvaluator(onx)
126+
a = np.arange(10).astype(np.float32)
127+
got = ref.run(None, {"X": a})[0]
128+
self.assertEqualArray(np.exp(a), got)
129+
130+
tr = Translater(onx, emitter=BuilderEmitter("mm"))
131+
code = tr.export(as_str=True)
132+
133+
expected = dedent(
134+
"""
135+
def light_api(
136+
op: "GraphBuilder",
137+
X: "FLOAT[]",
138+
):
139+
Y = op.Exp(X)
140+
op.Identity(Y, outputs=["Y"])
141+
return Y
142+
143+
144+
def mm() -> "ModelProto":
145+
g = GraphBuilder({'': 19}, ir_version=10)
146+
g.make_tensor_input("X", TensorProto.FLOAT, ())
147+
light_api(g.op, "X")
148+
g.make_tensor_output("Y", TensorProto.FLOAT, ())
149+
model = g.to_onnx()
150+
return model
151+
152+
153+
model = mm()
154+
"""
155+
).strip("\n")
156+
self.assertEqual(expected, code.strip("\n"))
157+
158+
def light_api(
159+
op: "GraphBuilder",
160+
X: "FLOAT[]", # noqa: F722
161+
):
162+
Y = op.Exp(X)
163+
op.Identity(Y, outputs=["Y"])
164+
return Y
165+
166+
g2 = GraphBuilder({"": 19})
167+
g2.make_tensor_input("X", TensorProto.FLOAT, ("A",))
168+
light_api(g2.op, "X")
169+
g2.make_tensor_output("Y", TensorProto.FLOAT, ("A",))
170+
onx2 = g2.to_onnx()
171+
172+
ref = ReferenceEvaluator(onx2)
173+
a = np.arange(10).astype(np.float32)
174+
got = ref.run(None, {"X": a})[0]
175+
self.assertEqualArray(np.exp(a), got)
176+
120177

121178
if __name__ == "__main__":
122179
unittest.main(verbosity=2)

onnx_array_api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
APIs to create ONNX Graphs.
33
"""
44

5-
__version__ = "0.3.0"
5+
__version__ = "0.3.1"
66
__author__ = "Xavier Dupré"

onnx_array_api/translate_api/builder_emitter.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@
44
from .base_emitter import BaseEmitter
55

66
_types = {
7+
TensorProto.DOUBLE: "DOUBLE",
78
TensorProto.FLOAT: "FLOAT",
89
TensorProto.FLOAT16: "FLOAT16",
910
TensorProto.INT64: "INT64",
1011
TensorProto.INT32: "INT32",
12+
TensorProto.INT16: "INT16",
13+
TensorProto.UINT64: "UINT64",
14+
TensorProto.UINT32: "UINT32",
15+
TensorProto.UINT16: "UINT16",
16+
TensorProto.STRING: "STRING",
17+
TensorProto.BOOL: "BOOL",
1118
}
1219

1320

@@ -20,6 +27,10 @@ class BuilderEmitter(BaseEmitter):
2027
Converts event into proper code.
2128
"""
2229

30+
def __init__(self, make_model_function: str = ""):
31+
super().__init__()
32+
self.make_model_function = make_model_function
33+
2334
def join(self, rows: List[str], single_line: bool = False) -> str:
2435
"Join the rows"
2536
assert (
@@ -29,6 +40,7 @@ def join(self, rows: List[str], single_line: bool = False) -> str:
2940

3041
def _emit_start(self, **kwargs: Dict[str, Any]) -> List[str]:
3142
self.opsets = kwargs.get("opsets", {})
43+
self.ir_version = kwargs.get("ir_version", None)
3244
return []
3345

3446
def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
@@ -43,12 +55,27 @@ def _emit_to_onnx_model(self, **kwargs: Dict[str, Any]) -> List[str]:
4355
)
4456
rows = [
4557
"",
46-
f"g = GraphBuilder({self.opsets})",
58+
(
59+
f"g = GraphBuilder({self.opsets}, ir_version={self.ir_version})"
60+
if self.ir_version
61+
else f"GraphBuilder({self.opsets})"
62+
),
4763
*inputs,
4864
f"{self.name}({inps})",
4965
*outputs,
5066
"model = g.to_onnx()",
5167
]
68+
if self.make_model_function:
69+
rows = [
70+
"",
71+
"",
72+
f'def {self.make_model_function}() -> "ModelProto":',
73+
*[" " + _ for _ in rows[1:]],
74+
" return model",
75+
"",
76+
"",
77+
f"model = {self.make_model_function}()",
78+
]
5279
return rows
5380

5481
def _emit_begin_graph(self, **kwargs: Dict[str, Any]) -> List[str]:
@@ -78,13 +105,16 @@ def _emit_input(self, **kwargs: Dict[str, Any]) -> List[str]:
78105
name = kwargs["name"]
79106
itype = kwargs.get("elem_type", 0)
80107
shape = kwargs.get("shape", None)
108+
name = self._clean_result_name(name)
81109
if itype == 0:
82-
inp = "X"
110+
inp = name or "X"
83111
else:
84112
if shape is None:
85-
inp = f'X: "{_itype_to_string(itype)}"'
113+
inp = f'{name}: "{_itype_to_string(itype)}"'
86114
else:
87-
inp = f'X: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
115+
inp = (
116+
f'{name}: "{_itype_to_string(itype)}[{", ".join(map(str, shape))}]"'
117+
)
88118
self.inputs_full.append(inp)
89119
self.inputs.append(name)
90120
self.inputs_full_.append((name, _itype_to_string(itype), shape))
@@ -113,6 +143,7 @@ def _emit_end_return(self, **kwargs: Dict[str, Any]) -> List[str]:
113143

114144
def _emit_output(self, **kwargs: Dict[str, Any]) -> List[str]:
115145
name = kwargs["name"]
146+
name = self._clean_result_name(name)
116147
itype = kwargs.get("elem_type", 0)
117148
shape = kwargs.get("shape", None)
118149
self.outputs.append(name)
@@ -126,6 +157,8 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
126157
if kwargs.get("domain", "") != "":
127158
domain = kwargs["domain"]
128159
op_type = f"{domain}.{op_type}"
160+
else:
161+
domain = ""
129162
atts = kwargs.get("atts", {})
130163
args = []
131164
for k, v in atts.items():
@@ -134,11 +167,22 @@ def _emit_node(self, **kwargs: Dict[str, Any]) -> List[str]:
134167
raise NotImplementedError("Graph attribute not supported yet.")
135168
args.append(f"{k}={vatt}")
136169

137-
outs = ", ".join(outputs)
138-
inps = ", ".join(inputs)
170+
outs = ", ".join(map(self._clean_result_name, outputs))
171+
inps = ", ".join(map(self._clean_result_name, inputs))
172+
op_type = self._emit_node_type(op_type, domain)
173+
sdomain = "" if not domain else f", domain={domain!r}"
139174
if args:
140175
sargs = ", ".join(args)
141-
row = f" {outs} = op.{op_type}({inps}, {sargs})"
176+
if inps:
177+
row = f" {outs} = op.{op_type}({inps}, {sargs}{sdomain})"
178+
else:
179+
row = f" {outs} = op.{op_type}({sargs}{sdomain})"
142180
else:
143-
row = f" {outs} = op.{op_type}({inps})"
181+
row = f" {outs} = op.{op_type}({inps}{sdomain})"
144182
return [row]
183+
184+
def _clean_result_name(self, name):
185+
return name
186+
187+
def _emit_node_type(self, op_type, domain):
188+
return op_type

onnx_array_api/translate_api/translate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
3535
last_event = None
3636
if isinstance(self.proto_, ModelProto):
3737
opsets = {d.domain: d.version for d in self.proto_.opset_import}
38-
rows.extend(self.emitter(EventType.START, opsets=opsets))
38+
rows.extend(
39+
self.emitter(
40+
EventType.START, opsets=opsets, ir_version=self.proto_.ir_version
41+
)
42+
)
3943
inputs = self.proto_.graph.input
4044
outputs = self.proto_.graph.output
4145
nodes = self.proto_.graph.node

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy