Skip to content

Commit 0c2a92d

Browse files
committed
fix unit test
1 parent af88e8d commit 0c2a92d

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

_unittests/ut_translate_api/test_translate.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,9 @@ def test_transpose(self):
8080
"""
8181
(
8282
start(opset=19)
83+
.vin('X', elem_type=TensorProto.FLOAT)
8384
.cst(np.array([-1, 1], dtype=np.int64))
8485
.rename('r')
85-
.vin('X', elem_type=TensorProto.FLOAT)
8686
.bring('X', 'r')
8787
.Reshape()
8888
.rename('r0_0')
@@ -166,9 +166,9 @@ def test_export_if(self):
166166
f"""
167167
(
168168
start(opset=19)
169+
.vin('X', elem_type=TensorProto.FLOAT)
169170
.cst(np.array([0.0], dtype=np.float32))
170171
.rename('r')
171-
.vin('X', elem_type=TensorProto.FLOAT)
172172
.bring('X')
173173
.ReduceSum(keepdims=1, noop_with_empty_axes=0)
174174
.rename('Xs')
@@ -202,9 +202,9 @@ def test_aionnxml(self):
202202
"""
203203
(
204204
start(opset=19, opsets={'ai.onnx.ml': 3})
205+
.vin('X', elem_type=TensorProto.FLOAT)
205206
.cst(np.array([-1, 1], dtype=np.int64))
206207
.rename('r')
207-
.vin('X', elem_type=TensorProto.FLOAT)
208208
.bring('X', 'r')
209209
.Reshape()
210210
.rename('USE')

_unittests/ut_translate_api/test_translate_classic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ def test_transpose(self):
138138
initializers = []
139139
sparse_initializers = []
140140
functions = []
141+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
141142
initializers.append(
142143
from_array(
143144
np.array([-1, 1], dtype=np.int64),
144145
name='r'
145146
)
146147
)
147-
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
148148
nodes.append(
149149
make_node_extended(
150150
'Reshape',
@@ -278,13 +278,13 @@ def test_aionnxml(self):
278278
initializers = []
279279
sparse_initializers = []
280280
functions = []
281+
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
281282
initializers.append(
282283
from_array(
283284
np.array([-1, 1], dtype=np.int64),
284285
name='r'
285286
)
286287
)
287-
inputs.append(make_tensor_value_info('X', TensorProto.FLOAT, shape=[]))
288288
nodes.append(
289289
make_node_extended(
290290
'Reshape',

onnx_array_api/translate_api/translate.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def export(self, as_str, single_line: bool = False) -> Union[str, List[str]]:
7575
domain=self.proto_.domain,
7676
)
7777
)
78+
elif isinstance(self.proto_, GraphProto):
79+
rows.extend(self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.name))
7880
else:
7981
rows.extend(
8082
self.emitter(EventType.BEGIN_GRAPH, name=self.proto_.graph.name)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy