Skip to content

Commit 271c29d

Browse files
committed
Add a mode to compare model without execution
1 parent d1aff97 commit 271c29d

File tree

3 files changed

+219
-31
lines changed

3 files changed

+219
-31
lines changed

_unittests/ut_reference/test_evaluator_yield.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import numpy as np
33
from onnx import TensorProto
4+
from onnx.checker import check_model
45
from onnx.helper import (
56
make_function,
67
make_graph,
@@ -9,6 +10,7 @@
910
make_opsetid,
1011
make_tensor_value_info,
1112
)
13+
from onnx.numpy_helper import from_array
1214
from onnx.parser import parse_model
1315
from onnx_array_api.ext_test_case import ExtTestCase
1416
from onnx_array_api.reference import (
@@ -426,7 +428,7 @@ def test_distance_sequence_str(self):
426428
002=|INPUTfloat322x2ABCDB|INPUTfloat322x2ABCDB
427429
003~|INPUTfloat322x3ABCDX|INPUTfloat322x2ABCDX
428430
004-|RESULTfloat322x2CEIOExpH|
429-
005=|RESULTfloat322x2CEIOLinearRegrY1|RESULTfloat322x2CEIOLinearRegrY1
431+
005=|RESULTfloat322x2CEIOLinearRegresY1|RESULTfloat322x2CEIOLinearRegresY1
430432
006~|RESULTfloat322x2CEIOAbsY|RESULTfloat322x3CEIPAbsZ
431433
007~|OUTPUTfloat322x2CEIOY|OUTPUTfloat322x2CEIPY
432434
""".replace(
@@ -460,6 +462,68 @@ def test_compare_execution(self):
460462
self.assertIn("CAAA Constant", text)
461463
self.assertEqual(len(align), 5)
462464

465+
def test_no_execution(self):
466+
model = make_model(
467+
make_graph(
468+
[
469+
make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
470+
make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
471+
make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
472+
make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
473+
make_node("Cast", ["xm2c"], ["xm2"], to=1),
474+
make_node("MatMul", ["xm1", "xm2"], ["xm"]),
475+
make_node("Reshape", ["xm", "shape3"], ["Z"]),
476+
],
477+
"dummy",
478+
[
479+
make_tensor_value_info("X", TensorProto.FLOAT, [32, 128]),
480+
make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5, 128, 64]),
481+
],
482+
[make_tensor_value_info("Z", TensorProto.FLOAT, [3, 5, 32, "N"])],
483+
[
484+
from_array(np.array([0], dtype=np.int64), name="zero"),
485+
from_array(np.array([1], dtype=np.int64), name="un"),
486+
from_array(np.array([1, 32, 128], dtype=np.int64), name="shape1"),
487+
from_array(np.array([15, 128, 64], dtype=np.int64), name="shape2"),
488+
from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="shape3"),
489+
],
490+
)
491+
)
492+
check_model(model)
493+
res1, res2, align, dc = compare_onnx_execution(model, model, mode="nodes")
494+
text = dc.to_str(res1, res2, align)
495+
self.assertIn("012 = | NODE", text)
496+
497+
model2 = make_model(
498+
make_graph(
499+
[
500+
make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
501+
make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
502+
make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
503+
make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
504+
make_node("MatMul", ["xm1", "xm2c"], ["xm"]),
505+
make_node("Reshape", ["xm", "shape3"], ["Z"]),
506+
],
507+
"dummy",
508+
[
509+
make_tensor_value_info("X", TensorProto.FLOAT, [32, 128]),
510+
make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5, 128, 64]),
511+
],
512+
[make_tensor_value_info("Z", TensorProto.FLOAT, [3, 5, 32, "N"])],
513+
[
514+
from_array(np.array([0], dtype=np.int64), name="zero"),
515+
from_array(np.array([1], dtype=np.int64), name="un"),
516+
from_array(np.array([1, 32, 128], dtype=np.int64), name="shape1"),
517+
from_array(np.array([15, 128, 64], dtype=np.int64), name="shape2"),
518+
from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="shape3"),
519+
],
520+
)
521+
)
522+
check_model(model2)
523+
res1, res2, align, dc = compare_onnx_execution(model, model2, mode="nodes")
524+
text = dc.to_str(res1, res2, align)
525+
self.assertIn("012 = | NODE", text)
526+
463527

464528
if __name__ == "__main__":
465529
unittest.main(verbosity=2)

onnx_array_api/_command_lines_parser.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_main_parser() -> ArgumentParser:
2020
Selects a command.
2121
2222
'translate' exports an onnx graph into a piece of code replicating it,
23-
'compares' compares the execution of two onnx models
23+
'compare' compares the execution of two onnx models
2424
"""
2525
),
2626
)
@@ -90,6 +90,13 @@ def get_parser_compare() -> ArgumentParser:
9090
required=True,
9191
help="second onnx model",
9292
)
93+
parser.add_argument(
94+
"-m",
95+
"--mode",
96+
choices=["execute", "nodes"],
97+
default="execute",
98+
help="compare the execution ('execute') or the nodes only ('nodes')",
99+
)
93100
parser.add_argument(
94101
"-v",
95102
"--verbose",
@@ -112,7 +119,9 @@ def _cmd_compare(argv: List[Any]):
112119
args = parser.parse_args(argv[1:])
113120
onx1 = onnx.load(args.model1)
114121
onx2 = onnx.load(args.model2)
115-
res1, res2, align, dc = compare_onnx_execution(onx1, onx2, verbose=args.verbose)
122+
res1, res2, align, dc = compare_onnx_execution(
123+
onx1, onx2, verbose=args.verbose, mode=args.mode
124+
)
116125
text = dc.to_str(res1, res2, align, column_size=args.column_size)
117126
print(text)
118127

@@ -127,7 +136,7 @@ def main(argv: Optional[List[Any]] = None):
127136
parser = get_main_parser()
128137
parser.parse_args(argv)
129138
else:
130-
parsers = dict(translate=get_parser_translate)
139+
parsers = dict(translate=get_parser_translate, compare=get_parser_compare)
131140
cmd = argv[0]
132141
if cmd not in parsers:
133142
raise ValueError(

onnx_array_api/reference/evaluator_yield.py

Lines changed: 142 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
from typing import Any, Dict, List, Iterator, Optional, Tuple, Union
33
from enum import IntEnum
44
import numpy as np
5-
from onnx import ModelProto, TensorProto, ValueInfoProto
5+
from onnx import ModelProto, TensorProto, ValueInfoProto, load
6+
from onnx.helper import tensor_dtype_to_np_dtype
7+
from onnx.shape_inference import infer_shapes
68
from .evaluator import ExtendedReferenceEvaluator
79

810

@@ -20,6 +22,7 @@ class ResultType(IntEnum):
2022
SPARSE_INITIALIZER = 4
2123
INPUT = 8
2224
OUTPUT = 16
25+
NODE = 32
2326

2427
def __repr__(self):
2528
return f"{self.__class__.__name__}.{self._name_}"
@@ -57,12 +60,13 @@ def __getitem__(self, i: int) -> Any:
5760
raise IndexError(f"i={i} out of boundary")
5861

5962
def __str__(self):
63+
dtype = self.dtype if self.dtype != 0 else ""
6064
els = [
6165
_align(self.kind._name_, 6),
62-
_align(str(self.dtype).replace("dtype(", "").replace(")", ""), 8),
63-
_align("x".join(map(str, self.shape)), 15),
66+
_align(str(dtype).replace("dtype(", "").replace(")", ""), 8),
67+
_align("x".join("" if self.shape is None else map(str, self.shape)), 15),
6468
self.summary,
65-
_align(self.op_type or "", 10),
69+
_align(self.op_type or "", 12),
6670
self.name or "",
6771
]
6872
return " ".join(els)
@@ -270,6 +274,22 @@ def _cost_type(self, t1: "np.dtype", t2: "np.dtype") -> float:
270274
return 1
271275

272276
def _cost_shape(self, s1: Tuple[int, ...], s2: Tuple[int, ...]) -> float:
277+
if s1 is None or s2 is None:
278+
return self.rank_cost
279+
if any(map(lambda s: isinstance(s, str), s1)) or any(
280+
map(lambda s: isinstance(s, str), s2)
281+
):
282+
# dynamic shapes
283+
if len(s1) != len(s2):
284+
return self.rank_cost
285+
d = 0
286+
for i, j in zip(s1, s2):
287+
if isinstance(i, int) and isinstance(j, int):
288+
d += abs(i - j)
289+
elif i != j:
290+
d += self.rank_cost / 2
291+
return d
292+
273293
d = abs(np.prod(s1) - np.prod(s2))
274294
if len(s1) != len(s2):
275295
return self.rank_cost + d
@@ -424,12 +444,85 @@ def generate_inputs(model: ModelProto) -> List[np.ndarray]:
424444
return inputs
425445

426446

447+
def _update_shape_types_with_proto(
448+
proto: ModelProto,
449+
) -> Dict[str, Tuple[int, Tuple[Union[int, str], ...]]]:
450+
"""
451+
Retrieves the shapes and types for a model.
452+
"""
453+
assert isinstance(proto, ModelProto), f"Unexpected type {type(proto)} for proto"
454+
res = {}
455+
456+
for val in proto.graph.input:
457+
itype = val.type.tensor_type.elem_type
458+
shape = tuple(
459+
d.dim_param if d.dim_param else d.dim_value
460+
for d in val.type.tensor_type.shape.dim
461+
)
462+
res[val.name] = [itype, shape]
463+
464+
for val in proto.graph.output:
465+
itype = val.type.tensor_type.elem_type
466+
shape = tuple(
467+
d.dim_param if d.dim_param else d.dim_value
468+
for d in val.type.tensor_type.shape.dim
469+
)
470+
res[val.name] = [itype, shape]
471+
472+
for val in proto.graph.initializer:
473+
itype = val.data_type
474+
shape = tuple(d for d in val.dims)
475+
res[val.name] = [itype, shape]
476+
477+
new_proto = infer_shapes(proto)
478+
for val in new_proto.graph.value_info:
479+
itype = val.type.tensor_type.elem_type
480+
shape = tuple(
481+
d.dim_param if d.dim_param else d.dim_value
482+
for d in val.type.tensor_type.shape.dim
483+
)
484+
res[val.name] = [itype, shape]
485+
486+
return res
487+
488+
489+
def _enumerate_result_no_execution(model: ModelProto) -> Iterator[ResultType]:
490+
"""
491+
Produces a list of results based on a model in order to
492+
trigger the edit distance comparison.
493+
"""
494+
type_shape = _update_shape_types_with_proto(model)
495+
for i in model.graph.initializer:
496+
itype, shape = type_shape.get(i.name, (0, None))
497+
dtype = tensor_dtype_to_np_dtype(itype)
498+
yield ResultExecution(
499+
ResultType.INITIALIZER, dtype, shape, "????", "INIT", i.name
500+
)
501+
for i in model.graph.input:
502+
itype, shape = type_shape.get(i.name, (0, None))
503+
dtype = tensor_dtype_to_np_dtype(itype)
504+
yield ResultExecution(ResultType.INPUT, dtype, shape, "????", "INPUT", i.name)
505+
for node in model.graph.node:
506+
yield ResultExecution(ResultType.NODE, 0, None, "????", node.op_type, node.name)
507+
for o in node.output:
508+
itype, shape = type_shape.get(o, (0, None))
509+
dtype = tensor_dtype_to_np_dtype(itype)
510+
yield ResultExecution(
511+
ResultType.RESULT, dtype, shape, "????", node.op_type, o
512+
)
513+
for i in model.graph.output:
514+
itype, shape = type_shape.get(i.name, (0, None))
515+
dtype = tensor_dtype_to_np_dtype(itype)
516+
yield ResultExecution(ResultType.OUTPUT, dtype, shape, "????", "OUTPUT", i.name)
517+
518+
427519
def compare_onnx_execution(
428520
model1: ModelProto,
429521
model2: ModelProto,
430522
inputs: Optional[Union[List[Any], Tuple[Dict[str, Any]]]] = None,
431523
verbose: int = 0,
432524
raise_exc: bool = True,
525+
mode: str = "execute",
433526
) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
434527
"""
435528
Compares the execution of two onnx models.
@@ -443,33 +536,55 @@ def compare_onnx_execution(
443536
the same number of inputs or two dictionaries, one for each model
444537
:param verbose: verbosity
445538
:param raise_exc: raise exception if the execution fails or stop at the error
539+
:param mode: the model should be executed but the function can be executed
540+
but the comparison may append on nodes only
446541
:return: four results, a sequence of results for the first model and the second model,
447542
the alignment between the two, DistanceExecution
448543
"""
449-
if verbose:
450-
print("[compare_onnx_execution] generate inputs")
451-
if inputs is None:
452-
inputs = generate_inputs(model1)
453-
if isinstance(inputs, tuple):
454-
assert len(inputs) == 2, f"Unexpected number {len(inputs)} of inputs."
455-
feeds1, feeds2 = inputs
544+
assert mode in {"execute", "nodes"}, f"Unexpected value for mode={mode!r}."
545+
546+
if mode == "execute":
547+
if inputs is None:
548+
if verbose:
549+
print("[compare_onnx_execution] generate inputs")
550+
inputs = generate_inputs(model1)
551+
if isinstance(inputs, tuple):
552+
assert len(inputs) == 2, f"Unexpected number {len(inputs)} of inputs."
553+
feeds1, feeds2 = inputs
554+
else:
555+
feeds1 = {i.name: v for i, v in zip(model1.graph.input, inputs)}
556+
feeds2 = {i.name: v for i, v in zip(model2.graph.input, inputs)}
557+
assert isinstance(feeds1, dict), f"Unexpected type {type(feeds1)} for inputs"
558+
assert isinstance(feeds2, dict), f"Unexpected type {type(feeds2)} for inputs"
559+
if verbose:
560+
print(f"[compare_onnx_execution] execute with {len(inputs)} inputs")
561+
print("[compare_onnx_execution] execute first model")
562+
res1 = list(
563+
YieldEvaluator(model1).enumerate_summarized(
564+
None, feeds1, raise_exc=raise_exc
565+
)
566+
)
567+
if verbose:
568+
print(f"[compare_onnx_execution] got {len(res1)} results")
569+
print("[compare_onnx_execution] execute second model")
570+
res2 = list(
571+
YieldEvaluator(model2).enumerate_summarized(
572+
None, feeds2, raise_exc=raise_exc
573+
)
574+
)
575+
elif mode == "nodes":
576+
# No execution.
577+
if verbose:
578+
print("[compare_onnx_execution] loading first model")
579+
proto1 = load(model1) if isinstance(model1, str) else model2
580+
if verbose:
581+
print("[compare_onnx_execution] loading first model")
582+
proto2 = load(model2) if isinstance(model2, str) else model1
583+
res1 = list(_enumerate_result_no_execution(proto1))
584+
res2 = list(_enumerate_result_no_execution(proto2))
456585
else:
457-
feeds1 = {i.name: v for i, v in zip(model1.graph.input, inputs)}
458-
feeds2 = {i.name: v for i, v in zip(model2.graph.input, inputs)}
459-
assert isinstance(feeds1, dict), f"Unexpected type {type(feeds1)} for inputs"
460-
assert isinstance(feeds2, dict), f"Unexpected type {type(feeds2)} for inputs"
461-
if verbose:
462-
print(f"[compare_onnx_execution] got {len(inputs)} inputs")
463-
print("[compare_onnx_execution] execute first model")
464-
res1 = list(
465-
YieldEvaluator(model1).enumerate_summarized(None, feeds1, raise_exc=raise_exc)
466-
)
467-
if verbose:
468-
print(f"[compare_onnx_execution] got {len(res1)} results")
469-
print("[compare_onnx_execution] execute second model")
470-
res2 = list(
471-
YieldEvaluator(model2).enumerate_summarized(None, feeds2, raise_exc=raise_exc)
472-
)
586+
return
587+
473588
if verbose:
474589
print(f"[compare_onnx_execution] got {len(res2)} results")
475590
print("[compare_onnx_execution] compute edit distance")

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy