Skip to content

Commit 092dfa2

Browse files
committed
fix order
1 parent 0c2a92d commit 092dfa2

File tree

3 files changed

+39
-13
lines changed

3 files changed

+39
-13
lines changed

_unittests/ut_translate_api/test_translate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def test_transpose(self):
8080
"""
8181
(
8282
start(opset=19)
83-
.vin('X', elem_type=TensorProto.FLOAT)
8483
.cst(np.array([-1, 1], dtype=np.int64))
8584
.rename('r')
85+
.vin('X', elem_type=TensorProto.FLOAT)
8686
.bring('X', 'r')
8787
.Reshape()
8888
.rename('r0_0')
@@ -166,9 +166,9 @@ def test_export_if(self):
166166
f"""
167167
(
168168
start(opset=19)
169-
.vin('X', elem_type=TensorProto.FLOAT)
170169
.cst(np.array([0.0], dtype=np.float32))
171170
.rename('r')
171+
.vin('X', elem_type=TensorProto.FLOAT)
172172
.bring('X')
173173
.ReduceSum(keepdims=1, noop_with_empty_axes=0)
174174
.rename('Xs')
@@ -202,9 +202,9 @@ def test_aionnxml(self):
202202
"""
203203
(
204204
start(opset=19, opsets={'ai.onnx.ml': 3})
205-
.vin('X', elem_type=TensorProto.FLOAT)
206205
.cst(np.array([-1, 1], dtype=np.int64))
207206
.rename('r')
207+
.vin('X', elem_type=TensorProto.FLOAT)
208208
.bring('X', 'r')
209209
.Reshape()
210210
.rename('USE')

_unittests/ut_translate_api/test_translate_builder.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,32 @@ def light_api(
6565
got = ref.run(None, {"X": a})[0]
6666
self.assertEqualArray(np.exp(a), got)
6767

68+
def test_zdoc(self):
69+
onx = (
70+
start()
71+
.vin("X")
72+
.reshape((-1, 1))
73+
.Transpose(perm=[1, 0])
74+
.rename("Y")
75+
.vout()
76+
.to_onnx()
77+
)
78+
code = translate(onx, api="builder")
79+
expected = dedent(
80+
"""
81+
(
82+
start()
83+
.vin("X")
84+
.reshape((-1, 1))
85+
.Transpose(perm=[1, 0])
86+
.rename("Y")
87+
.vout()
88+
.to_onnx()
89+
)"""
90+
).strip("\n")
91+
self.maxDiff = None
92+
self.assertEqual(expected, code)
93+
6894

6995
if __name__ == "__main__":
7096
unittest.main(verbosity=2)

onnx_array_api/translate_api/translate.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,16 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
8282
self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name)
8383
)
8484

85+
for i in initializers:
86+
rows.extend(
87+
self.emitter(
88+
EventType.INITIALIZER,
89+
name=i.name,
90+
init=i,
91+
value=to_array_extended(i),
92+
)
93+
)
94+
8595
rows.extend(self.emitter(EventType.BEGIN_SIGNATURE))
8696

8797
for i in inputs:
@@ -107,16 +117,6 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
107117

108118
rows.extend(self.emitter(EventType.END_SIGNATURE))
109119

110-
for i in initializers:
111-
rows.extend(
112-
self.emitter(
113-
EventType.INITIALIZER,
114-
name=i.name,
115-
init=i,
116-
value=to_array_extended(i),
117-
)
118-
)
119-
120120
for node in nodes:
121121
atts = self.extract_attributes(node)
122122
rows.extend(

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy