Skip to content

Commit eb106e2

Browse files
authored
Export evaluator type in compare_onnx_execution (#93)
* Export evaluator type in compare_onnx_execution * doc * doc
1 parent 07c3683 commit eb106e2

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

LICENSE.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Copyright (c) 2023-2024, Xavier Dupré
1+
Copyright (c) 2023-2025, Xavier Dupré
22

33
Permission is hereby granted, free of charge, to any person obtaining a copy
44
of this software and associated documentation files (the "Software"), to deal

onnx_array_api/reference/evaluator_yield.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from enum import IntEnum
44
import numpy as np
55
from onnx import ModelProto, TensorProto, ValueInfoProto, load
6+
from onnx.reference import ReferenceEvaluator
67
from onnx.helper import tensor_dtype_to_np_dtype
78
from onnx.shape_inference import infer_shapes
89
from . import to_array_extended
@@ -138,17 +139,23 @@ class YieldEvaluator:
138139
139140
:param onnx_model: model to run
140141
:param recursive: dig into subgraph and functions as well
142+
:param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator
143+
<onnx_array_api.reference.ExtendedReferenceEvaluator>`
141144
"""
142145

143146
def __init__(
144147
self,
145148
onnx_model: ModelProto,
146149
recursive: bool = False,
147-
cls=ExtendedReferenceEvaluator,
150+
cls: Optional[type[ExtendedReferenceEvaluator]] = None,
148151
):
149152
assert not recursive, "recursive=True is not yet implemented"
150153
self.onnx_model = onnx_model
151-
self.evaluator = cls(onnx_model) if cls is not None else None
154+
self.evaluator = (
155+
cls(onnx_model)
156+
if cls is not None
157+
else ExtendedReferenceEvaluator(onnx_model)
158+
)
152159

153160
def enumerate_results(
154161
self,
@@ -166,9 +173,9 @@ def enumerate_results(
166173
Returns:
167174
iterator on tuple(result kind, name, value, node.op_type or None)
168175
"""
169-
assert isinstance(self.evaluator, ExtendedReferenceEvaluator), (
176+
assert isinstance(self.evaluator, ReferenceEvaluator), (
170177
f"This implementation only works with "
171-
f"ExtendedReferenceEvaluator not {type(self.evaluator)}"
178+
f"ReferenceEvaluator not {type(self.evaluator)}"
172179
)
173180
attributes = {}
174181
if output_names is None:
@@ -595,6 +602,7 @@ def compare_onnx_execution(
595602
raise_exc: bool = True,
596603
mode: str = "execute",
597604
keep_tensor: bool = False,
605+
cls: Optional[type[ReferenceEvaluator]] = None,
598606
) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
599607
"""
600608
Compares the execution of two onnx models.
@@ -611,6 +619,7 @@ def compare_onnx_execution(
611619
:param mode: the model should be executed but the function can be executed
612620
but the comparison may append on nodes only
613621
:param keep_tensor: keeps the tensor in order to compute a precise distance
622+
:param cls: evaluator class to use
614623
:return: four results, a sequence of results
615624
for the first model and the second model,
616625
the alignment between the two, DistanceExecution
@@ -634,15 +643,15 @@ def compare_onnx_execution(
634643
print(f"[compare_onnx_execution] execute with {len(inputs)} inputs")
635644
print("[compare_onnx_execution] execute first model")
636645
res1 = list(
637-
YieldEvaluator(model1).enumerate_summarized(
646+
YieldEvaluator(model1, cls=cls).enumerate_summarized(
638647
None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor
639648
)
640649
)
641650
if verbose:
642651
print(f"[compare_onnx_execution] got {len(res1)} results")
643652
print("[compare_onnx_execution] execute second model")
644653
res2 = list(
645-
YieldEvaluator(model2).enumerate_summarized(
654+
YieldEvaluator(model2, cls=cls).enumerate_summarized(
646655
None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor
647656
)
648657
)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy