From b11db3c1475bb999dffb49cbff852215ead5dd1e Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Jan 2025 16:28:14 +0100 Subject: [PATCH 1/3] Export evaluator type in compare_onnx_execution --- LICENSE.txt | 2 +- onnx_array_api/reference/evaluator_yield.py | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/LICENSE.txt b/LICENSE.txt index e027853..1a46a8e 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,4 @@ -Copyright (c) 2023-2024, Xavier Dupré +Copyright (c) 2023-2025, Xavier Dupré Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py index 5b77e8b..7d16be3 100644 --- a/onnx_array_api/reference/evaluator_yield.py +++ b/onnx_array_api/reference/evaluator_yield.py @@ -3,6 +3,7 @@ from enum import IntEnum import numpy as np from onnx import ModelProto, TensorProto, ValueInfoProto, load +from onnx.reference import ReferenceEvaluator from onnx.helper import tensor_dtype_to_np_dtype from onnx.shape_inference import infer_shapes from . import to_array_extended @@ -166,9 +167,9 @@ def enumerate_results( Returns: iterator on tuple(result kind, name, value, node.op_type or None) """ - assert isinstance(self.evaluator, ExtendedReferenceEvaluator), ( + assert isinstance(self.evaluator, ReferenceEvaluator), ( f"This implementation only works with " - f"ExtendedReferenceEvaluator not {type(self.evaluator)}" + f"ReferenceEvaluator not {type(self.evaluator)}" ) attributes = {} if output_names is None: @@ -595,6 +596,7 @@ def compare_onnx_execution( raise_exc: bool = True, mode: str = "execute", keep_tensor: bool = False, + cls: Optional[type[ReferenceEvaluator]] = None, ) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]: """ Compares the execution of two onnx models. @@ -611,6 +613,7 @@ def compare_onnx_execution( :param mode: the model should be executed but the function can be executed but the comparison may append on nodes only :param keep_tensor: keeps the tensor in order to compute a precise distance + :param cls: evaluator class to use :return: four results, a sequence of results for the first model and the second model, the alignment between the two, DistanceExecution @@ -634,7 +637,7 @@ def compare_onnx_execution( print(f"[compare_onnx_execution] execute with {len(inputs)} inputs") print("[compare_onnx_execution] execute first model") res1 = list( - YieldEvaluator(model1).enumerate_summarized( + YieldEvaluator(model1, cls=cls).enumerate_summarized( None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor ) ) @@ -642,7 +645,7 @@ def compare_onnx_execution( print(f"[compare_onnx_execution] got {len(res1)} results") print("[compare_onnx_execution] execute second model") res2 = list( - YieldEvaluator(model2).enumerate_summarized( + YieldEvaluator(model2, cls=cls).enumerate_summarized( None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor ) ) From 5e73e8d8395945c5432c44bdc44f606063794463 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Jan 2025 17:54:19 +0100 Subject: [PATCH 2/3] doc --- onnx_array_api/reference/evaluator_yield.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py index 7d16be3..82da956 100644 --- a/onnx_array_api/reference/evaluator_yield.py +++ b/onnx_array_api/reference/evaluator_yield.py @@ -139,17 +139,22 @@ class YieldEvaluator: :param onnx_model: model to run :param recursive: dig into subgraph and functions as well + :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator` """ def __init__( self, onnx_model: ModelProto, recursive: bool = False, - cls=ExtendedReferenceEvaluator, + cls: Optional[type[ExtendedReferenceEvaluator]] = None, ): assert not recursive, "recursive=True is not yet implemented" self.onnx_model = onnx_model - self.evaluator = cls(onnx_model) if cls is not None else None + self.evaluator = ( + cls(onnx_model) + if cls is not None + else ExtendedReferenceEvaluator(onnx_model) + ) def enumerate_results( self, From f3c75299a167022e26b1a390961ec5034e9d2779 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 6 Jan 2025 17:55:06 +0100 Subject: [PATCH 3/3] doc --- onnx_array_api/reference/evaluator_yield.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py index 82da956..6ae005c 100644 --- a/onnx_array_api/reference/evaluator_yield.py +++ b/onnx_array_api/reference/evaluator_yield.py @@ -139,7 +139,8 @@ class YieldEvaluator: :param onnx_model: model to run :param recursive: dig into subgraph and functions as well - :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator` + :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator + ` """ def __init__( pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy