Skip to content

Commit a54de21

Browse files
authored
Better support for ir_version (#82)
* fixes for ir_version * fix ut * fix ut
1 parent 492b6d4 commit a54de21

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

_unittests/ut_light_api/test_backend_export.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
242242

243243
# The following tests are too slow with the reference implementation (Conv).
244244
backend_test.exclude(
245-
"(FLOAT8|BFLOAT16|_opt_|_3d_|_momentum_|_4d_"
245+
"(FLOAT8|BFLOAT16|INT4|_opt_|_3d_|_momentum_|_4d_|int4"
246246
"|test_adagrad"
247247
"|test_adam"
248248
"|test_ai_onnx_ml_"
@@ -270,6 +270,8 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
270270
"|test_squeezenet"
271271
"|test_vgg19"
272272
"|test_zfnet512"
273+
"|test_range_float_type_positive_delta_expanded"
274+
"|test_range_int32_type_negative_delta_expanded"
273275
")"
274276
)
275277

_unittests/ut_reference/test_backend_extended_reference_evaluator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
149149
"|test_scan_sum)"
150150
)
151151

152-
if onnx_opset_version() < 21:
152+
if onnx_opset_version() < 200:
153153
# The following tests are using types not supported by NumPy.
154154
# They could be if method to_array is extended to support custom
155155
# types the same as the reference implementation does
@@ -164,8 +164,10 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
164164
"|test_cast_no_saturate_"
165165
"|_to_FLOAT8"
166166
"|_FLOAT8"
167+
"|INT4"
167168
"|test_quantizelinear_e4m3fn"
168169
"|test_quantizelinear_e5m2"
170+
"|test_scatter_with"
169171
")"
170172
)
171173

onnx_array_api/graph_api/graph_builder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def __init__(
156156
optimization_options: Optional[OptimizationOptions] = None,
157157
args: Optional[List[Any]] = None,
158158
verbose: int = 0,
159+
ir_version: Optional[int] = None,
159160
):
160161
self.optimization_options = optimization_options or OptimizationOptions()
161162
self.as_function = as_function
@@ -170,6 +171,7 @@ def __init__(
170171
if isinstance(target_opset_or_existing_proto, int)
171172
else target_opset_or_existing_proto
172173
)
174+
self.ir_version = ir_version
173175
self.nodes = []
174176
self.initializers_dict = {}
175177
self.inputs = []
@@ -186,6 +188,7 @@ def __init__(
186188
), "input_names must be empty if the input is an existing model."
187189
proto = target_opset_or_existing_proto
188190
self.opsets = {d.domain: d.version for d in proto.opset_import}
191+
self.ir_version = ir_version or target_opset_or_existing_proto.ir_version
189192
self.nodes = list(proto.graph.node)
190193
self.initializers_dict = {i.name: i for i in proto.graph.initializer}
191194
self.initializers_dict.update(
@@ -674,6 +677,8 @@ def to_onnx(
674677
if self.verbose:
675678
print("[GraphBuilder] onh.make_model")
676679
model = oh.make_model(graph, opset_imports=opsets)
680+
if self.ir_version:
681+
model.ir_version = self.ir_version
677682
return model
678683

679684
def _check_order_node(self, ind: int, node: NodeProto, existing: Set[str]):

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy