Skip to content

Commit 4c12efd

Browse files
committed
2 parents d1aff97 + 4cf9dcc commit 4c12efd

File tree

4 files changed

+255
-38
lines changed

4 files changed

+255
-38
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`76`: add a mode to compare models without execution
78
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
89
* :pr:`71`: adds tools to compare two onnx graphs
910
* :pr:`61`: adds function to plot onnx model as graphs

_unittests/ut_reference/test_evaluator_yield.py

Lines changed: 71 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import numpy as np
33
from onnx import TensorProto
4+
from onnx.checker import check_model
45
from onnx.helper import (
56
make_function,
67
make_graph,
@@ -9,6 +10,7 @@
910
make_opsetid,
1011
make_tensor_value_info,
1112
)
13+
from onnx.numpy_helper import from_array
1214
from onnx.parser import parse_model
1315
from onnx_array_api.ext_test_case import ExtTestCase
1416
from onnx_array_api.reference import (
@@ -422,13 +424,13 @@ def test_distance_sequence_str(self):
422424
text = dc.to_str(s1, s2, align)
423425
self.assertIn("OUTPUT", text)
424426
expected = """
425-
001=|INPUTfloat322x2ABCDA|INPUTfloat322x2ABCDA
426-
002=|INPUTfloat322x2ABCDB|INPUTfloat322x2ABCDB
427-
003~|INPUTfloat322x3ABCDX|INPUTfloat322x2ABCDX
428-
004-|RESULTfloat322x2CEIOExpH|
429-
005=|RESULTfloat322x2CEIOLinearRegrY1|RESULTfloat322x2CEIOLinearRegrY1
430-
006~|RESULTfloat322x2CEIOAbsY|RESULTfloat322x3CEIPAbsZ
431-
007~|OUTPUTfloat322x2CEIOY|OUTPUTfloat322x2CEIPY
427+
001=|INPUTfloat322:2x2ABCDA|INPUTfloat322:2x2ABCDA
428+
002=|INPUTfloat322:2x2ABCDB|INPUTfloat322:2x2ABCDB
429+
003~|INPUTfloat322:2x3ABCDX|INPUTfloat322:2x2ABCDX
430+
004-|RESULTfloat322:2x2CEIOExpH|
431+
005=|RESULTfloat322:2x2CEIOLinearRegressioY1|RESULTfloat322:2x2CEIOLinearRegressioY1
432+
006~|RESULTfloat322:2x2CEIOAbsY|RESULTfloat322:2x3CEIPAbsZ
433+
007~|OUTPUTfloat322:2x2CEIOY|OUTPUTfloat322:2x2CEIPY
432434
""".replace(
433435
" ", ""
434436
).strip(
@@ -460,6 +462,68 @@ def test_compare_execution(self):
460462
self.assertIn("CAAA Constant", text)
461463
self.assertEqual(len(align), 5)
462464

465+
def test_no_execution(self):
466+
model = make_model(
467+
make_graph(
468+
[
469+
make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
470+
make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
471+
make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
472+
make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
473+
make_node("Cast", ["xm2c"], ["xm2"], to=1),
474+
make_node("MatMul", ["xm1", "xm2"], ["xm"]),
475+
make_node("Reshape", ["xm", "shape3"], ["Z"]),
476+
],
477+
"dummy",
478+
[
479+
make_tensor_value_info("X", TensorProto.FLOAT, [32, 128]),
480+
make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5, 128, 64]),
481+
],
482+
[make_tensor_value_info("Z", TensorProto.FLOAT, [3, 5, 32, "N"])],
483+
[
484+
from_array(np.array([0], dtype=np.int64), name="zero"),
485+
from_array(np.array([1], dtype=np.int64), name="un"),
486+
from_array(np.array([1, 32, 128], dtype=np.int64), name="shape1"),
487+
from_array(np.array([15, 128, 64], dtype=np.int64), name="shape2"),
488+
from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="shape3"),
489+
],
490+
)
491+
)
492+
check_model(model)
493+
res1, res2, align, dc = compare_onnx_execution(model, model, mode="nodes")
494+
text = dc.to_str(res1, res2, align)
495+
self.assertIn("012 = | NODE", text)
496+
497+
model2 = make_model(
498+
make_graph(
499+
[
500+
make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
501+
make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
502+
make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
503+
make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
504+
make_node("MatMul", ["xm1", "xm2c"], ["xm"]),
505+
make_node("Reshape", ["xm", "shape3"], ["Z"]),
506+
],
507+
"dummy",
508+
[
509+
make_tensor_value_info("X", TensorProto.FLOAT, [32, 128]),
510+
make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5, 128, 64]),
511+
],
512+
[make_tensor_value_info("Z", TensorProto.FLOAT, [3, 5, 32, "N"])],
513+
[
514+
from_array(np.array([0], dtype=np.int64), name="zero"),
515+
from_array(np.array([1], dtype=np.int64), name="un"),
516+
from_array(np.array([1, 32, 128], dtype=np.int64), name="shape1"),
517+
from_array(np.array([15, 128, 64], dtype=np.int64), name="shape2"),
518+
from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="shape3"),
519+
],
520+
)
521+
)
522+
check_model(model2)
523+
res1, res2, align, dc = compare_onnx_execution(model, model2, mode="nodes")
524+
text = dc.to_str(res1, res2, align)
525+
self.assertIn("012 = | NODE", text)
526+
463527

464528
if __name__ == "__main__":
465529
unittest.main(verbosity=2)

onnx_array_api/_command_lines_parser.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_main_parser() -> ArgumentParser:
2020
Selects a command.
2121
2222
'translate' exports an onnx graph into a piece of code replicating it,
23-
'compares' compares the execution of two onnx models
23+
'compare' compares the execution of two onnx models
2424
"""
2525
),
2626
)
@@ -90,6 +90,13 @@ def get_parser_compare() -> ArgumentParser:
9090
required=True,
9191
help="second onnx model",
9292
)
93+
parser.add_argument(
94+
"-m",
95+
"--mode",
96+
choices=["execute", "nodes"],
97+
default="execute",
98+
help="compare the execution ('execute') or the nodes only ('nodes')",
99+
)
93100
parser.add_argument(
94101
"-v",
95102
"--verbose",
@@ -112,8 +119,10 @@ def _cmd_compare(argv: List[Any]):
112119
args = parser.parse_args(argv[1:])
113120
onx1 = onnx.load(args.model1)
114121
onx2 = onnx.load(args.model2)
115-
res1, res2, align, dc = compare_onnx_execution(onx1, onx2, verbose=args.verbose)
116-
text = dc.to_str(res1, res2, align, column_size=args.column_size)
122+
res1, res2, align, dc = compare_onnx_execution(
123+
onx1, onx2, verbose=args.verbose, mode=args.mode
124+
)
125+
text = dc.to_str(res1, res2, align, column_size=int(args.column_size))
117126
print(text)
118127

119128

@@ -127,7 +136,7 @@ def main(argv: Optional[List[Any]] = None):
127136
parser = get_main_parser()
128137
parser.parse_args(argv)
129138
else:
130-
parsers = dict(translate=get_parser_translate)
139+
parsers = dict(translate=get_parser_translate, compare=get_parser_compare)
131140
cmd = argv[0]
132141
if cmd not in parsers:
133142
raise ValueError(

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy